diff --git a/cmd/crowdsec-cli/clinotifications/notifications.go b/cmd/crowdsec-cli/clinotifications/notifications.go index d0f2234d856..d7fa1ef27a7 100644 --- a/cmd/crowdsec-cli/clinotifications/notifications.go +++ b/cmd/crowdsec-cli/clinotifications/notifications.go @@ -4,7 +4,6 @@ import ( "context" "encoding/csv" "encoding/json" - "errors" "fmt" "io/fs" "net/url" @@ -19,7 +18,6 @@ import ( "github.com/go-openapi/strfmt" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" - "gopkg.in/tomb.v2" "gopkg.in/yaml.v3" "github.com/crowdsecurity/go-cs-lib/ptr" @@ -275,7 +273,6 @@ func (cli *cliNotifications) notificationConfigFilter(_ *cobra.Command, args []s func (cli *cliNotifications) newTestCmd() *cobra.Command { var ( pluginBroker csplugin.PluginBroker - pluginTomb tomb.Tomb alertOverride string ) @@ -315,11 +312,19 @@ func (cli *cliNotifications) newTestCmd() *cobra.Command { }, }, cfg.ConfigPaths) }, - RunE: func(_ *cobra.Command, _ []string) error { - pluginTomb.Go(func() error { - pluginBroker.Run(&pluginTomb) - return nil - }) + RunE: func(cmd *cobra.Command, _ []string) error { + pluginCtx, cancelPlugin := context.WithCancel(cmd.Context()) + pluginDone := make(chan struct{}) + + go func() { + pluginBroker.Run(pluginCtx) + close(pluginDone) + }() + defer func() { + cancelPlugin() + <-pluginDone + }() + alert := &models.Alert{ Capacity: ptr.Of(int32(0)), Decisions: []*models.Decision{{ @@ -360,16 +365,12 @@ func (cli *cliNotifications) newTestCmd() *cobra.Command { Alert: alert, } - // time.Sleep(2 * time.Second) // There's no mechanism to ensure notification has been sent - pluginTomb.Kill(errors.New("terminating")) - _ = pluginTomb.Wait() - return nil }, } cmd.Flags().StringVarP(&alertOverride, "alert", "a", "", - "JSON string used to override alert fields in the generic alert " + - "(see crowdsec/pkg/models/alert.go in the source tree for the full definition of the object)") + "JSON string used to override alert fields in the generic alert "+ + "(see 
crowdsec/pkg/models/alert.go in the source tree for the full definition of the object)") return cmd } @@ -401,10 +402,7 @@ cscli notifications reinject -a '{"remediation": true,"scenario":"not return nil }, RunE: func(cmd *cobra.Command, _ []string) error { - var ( - pluginBroker csplugin.PluginBroker - pluginTomb tomb.Tomb - ) + var pluginBroker csplugin.PluginBroker ctx := cmd.Context() cfg := cli.cfg() @@ -442,10 +440,17 @@ cscli notifications reinject -a '{"remediation": true,"scenario":"not return fmt.Errorf("can't initialize plugins: %w", err) } - pluginTomb.Go(func() error { - pluginBroker.Run(&pluginTomb) - return nil - }) + pluginCtx, cancelPlugin := context.WithCancel(ctx) + pluginDone := make(chan struct{}) + + go func() { + pluginBroker.Run(pluginCtx) + close(pluginDone) + }() + defer func() { + cancelPlugin() + <-pluginDone + }() profiles, err := csprofiles.NewProfile(cfg.API.Server.Profiles) if err != nil { @@ -481,16 +486,12 @@ cscli notifications reinject -a '{"remediation": true,"scenario":"not break } } - // time.Sleep(2 * time.Second) // There's no mechanism to ensure notification has been sent - pluginTomb.Kill(errors.New("terminating")) - _ = pluginTomb.Wait() - return nil }, } cmd.Flags().StringVarP(&alertOverride, "alert", "a", "", - "JSON string used to override alert fields in the reinjected alert " + - "(see crowdsec/pkg/models/alert.go in the source tree for the full definition of the object)") + "JSON string used to override alert fields in the reinjected alert "+ + "(see crowdsec/pkg/models/alert.go in the source tree for the full definition of the object)") return cmd } diff --git a/cmd/crowdsec/api.go b/cmd/crowdsec/api.go index 25f5ff03988..c27fb98404e 100644 --- a/cmd/crowdsec/api.go +++ b/cmd/crowdsec/api.go @@ -3,6 +3,7 @@ package main import ( "context" "fmt" + "net" "time" log "github.com/sirupsen/logrus" @@ -20,6 +21,15 @@ func initAPIServer(ctx context.Context, cConfig *csconfig.Config) (*apiserver.AP log.Info("push and pull 
to Central API disabled") } + if cConfig.API.Server.ListenURI != "" { + listenConfig := &net.ListenConfig{} + listener, err := listenConfig.Listen(ctx, "tcp", cConfig.API.Server.ListenURI) + if err != nil { + return nil, fmt.Errorf("local API server stopped with error: listening on %s: %w", cConfig.API.Server.ListenURI, err) + } + _ = listener.Close() + } + accessLogger := cConfig.API.Server.NewAccessLogger(cConfig.Common.LogConfig, accessLogFilename) apiServer, err := apiserver.NewServer(ctx, cConfig.API.Server, accessLogger) @@ -42,30 +52,62 @@ func initAPIServer(ctx context.Context, cConfig *csconfig.Config) (*apiserver.AP func serveAPIServer(ctx context.Context, apiServer *apiserver.APIServer) { apiReady := make(chan bool, 1) + runCtx, cancel := context.WithCancel(ctx) + apiCancel = cancel + apiDone = make(chan error, 1) - apiTomb.Go(func() error { + go func() { defer trace.ReportPanic() + pluginCtx, cancelPlugin := context.WithCancel(runCtx) + pluginDone := make(chan struct{}) + + go func() { + defer trace.ReportPanic() + defer close(pluginDone) + pluginBroker.Run(pluginCtx) + }() + + runErr := make(chan error, 1) + go func() { defer trace.ReportPanic() log.Debugf("serving API after %s ms", time.Since(crowdsecT0)) - if err := apiServer.Run(ctx, apiReady); err != nil { - log.Fatal(err) - } + runErr <- apiServer.Run(runCtx, apiReady) }() - pluginTomb.Go(func() error { - pluginBroker.Run(&pluginTomb) - return nil - }) - - <-apiTomb.Dying() // lock until go routine is dying - pluginTomb.Kill(nil) - log.Infof("serve: shutting down api server") + select { + case err := <-runErr: + if err != nil && runCtx.Err() == nil { + log.Fatal(err) + } + cancelPlugin() + <-pluginDone + apiDone <- err + case <-runCtx.Done(): + log.Infof("serve: shutting down api server") + + shutdownCtx, cancelShutdown := context.WithTimeout(context.WithoutCancel(runCtx), 5*time.Second) + err := apiServer.Shutdown(shutdownCtx) + cancelShutdown() + + if runErrValue := <-runErr; runErrValue != nil 
&& err == nil { + err = runErrValue + } - return apiServer.Shutdown(ctx) - }) - <-apiReady + cancelPlugin() + <-pluginDone + apiDone <- err + } + }() + + select { + case <-apiReady: + case err := <-apiDone: + if err != nil { + log.Fatal(err) + } + } } diff --git a/cmd/crowdsec/crowdsec.go b/cmd/crowdsec/crowdsec.go index 968d8dd556c..52ca463a26a 100644 --- a/cmd/crowdsec/crowdsec.go +++ b/cmd/crowdsec/crowdsec.go @@ -185,38 +185,37 @@ func serveCrowdsec( sd *StateDumper, ) { cctx, cancel := context.WithCancel(ctx) + crowdCancel = cancel + crowdDone = make(chan error, 1) var g errgroup.Group bucketStore := leakybucket.NewBucketStore() - crowdsecTomb.Go(func() error { + go func() { defer trace.ReportPanic() + // this logs every time, even at config reload + log.Debugf("running agent after %s ms", time.Since(crowdsecT0)) + agentReady <- true - go func() { - defer trace.ReportPanic() - // this logs every time, even at config reload - log.Debugf("running agent after %s ms", time.Since(crowdsecT0)) - - agentReady <- true - - if err := runCrowdsec(cctx, &g, cConfig, parsers, hub, datasources, sd, bucketStore); err != nil { - log.Fatalf("unable to start crowdsec routines: %s", err) - } - }() + if err := runCrowdsec(cctx, &g, cConfig, parsers, hub, datasources, sd, bucketStore); err != nil { + crowdDone <- fmt.Errorf("unable to start crowdsec routines: %w", err) + return + } /* we should stop in two cases : - - crowdsecTomb has been Killed() : it might be shutdown or reload, so stop + - context has been canceled: it might be shutdown or reload, so stop - acquisTomb is dead, it means that we were in "cat" mode and files are done reading, quit */ - waitOnTomb() + waitOnCrowdsecStop(cctx) log.Debugf("Shutting down crowdsec routines") if err := ShutdownCrowdsecRoutines(cancel, &g, datasources); err != nil { - return fmt.Errorf("unable to shutdown crowdsec routines: %w", err) + crowdDone <- fmt.Errorf("unable to shutdown crowdsec routines: %w", err) + return } - 
log.Debugf("everything is dead, return crowdsecTomb") + log.Debugf("everything is dead, return crowdsec routine") log.Debugf("sd.DumpDir == %s", sd.DumpDir) if sd.DumpDir != "" { @@ -229,11 +228,11 @@ func serveCrowdsec( os.Exit(0) } - return nil - }) + crowdDone <- nil + }() } -func waitOnTomb() { +func waitOnCrowdsecStop(ctx context.Context) { for { select { case <-acquisTomb.Dead(): @@ -256,7 +255,7 @@ func waitOnTomb() { return - case <-crowdsecTomb.Dying(): + case <-ctx.Done(): log.Infof("Crowdsec engine shutting down") return } diff --git a/cmd/crowdsec/main.go b/cmd/crowdsec/main.go index 9480e6dd08e..9b47c30335f 100644 --- a/cmd/crowdsec/main.go +++ b/cmd/crowdsec/main.go @@ -30,12 +30,13 @@ import ( ) var ( - // tombs for the parser, buckets and outputs. - acquisTomb tomb.Tomb - outputsTomb tomb.Tomb - apiTomb tomb.Tomb - crowdsecTomb tomb.Tomb - pluginTomb tomb.Tomb + // tombs for acquisition and output routines. + acquisTomb tomb.Tomb + outputsTomb tomb.Tomb + apiCancel context.CancelFunc + apiDone chan error + crowdCancel context.CancelFunc + crowdDone chan error flags Flags diff --git a/cmd/crowdsec/serve.go b/cmd/crowdsec/serve.go index e916daed4dd..c54d7a4e50f 100644 --- a/cmd/crowdsec/serve.go +++ b/cmd/crowdsec/serve.go @@ -27,12 +27,13 @@ import ( ) func reloadHandler(ctx context.Context, _ os.Signal) (*csconfig.Config, error) { - // re-initialize tombs + // re-initialize routines state acquisTomb = tomb.Tomb{} outputsTomb = tomb.Tomb{} - apiTomb = tomb.Tomb{} - crowdsecTomb = tomb.Tomb{} - pluginTomb = tomb.Tomb{} + apiCancel = nil + apiDone = nil + crowdCancel = nil + crowdDone = nil sd := NewStateDumper(flags.DumpDir) @@ -151,9 +152,6 @@ func ShutdownCrowdsecRoutines(cancel context.CancelFunc, g *errgroup.Group, data log.Debugf("buckets are done") log.Debugf("metrics are done") - // He's dead, Jim. 
- crowdsecTomb.Kill(nil) - // close the potential geoips reader we have to avoid leaking ressources on reload exprhelpers.GeoIPClose() @@ -161,11 +159,15 @@ func ShutdownCrowdsecRoutines(cancel context.CancelFunc, g *errgroup.Group, data } func shutdownAPI() error { - log.Debugf("shutting down api via Tomb") - apiTomb.Kill(nil) + log.Debugf("shutting down api via context") + if apiCancel != nil { + apiCancel() + } - if err := apiTomb.Wait(); err != nil { - return err + if apiDone != nil { + if err := <-apiDone; err != nil { + return err + } } log.Debugf("done") @@ -174,11 +176,17 @@ func shutdownAPI() error { } func shutdownCrowdsec() error { - log.Debugf("shutting down crowdsec via Tomb") - crowdsecTomb.Kill(nil) + log.Infof("Crowdsec engine shutting down") - if err := crowdsecTomb.Wait(); err != nil { - return err + log.Debugf("shutting down crowdsec via context") + if crowdCancel != nil { + crowdCancel() + } + + if crowdDone != nil { + if err := <-crowdDone; err != nil { + return err + } } log.Debugf("done") @@ -312,9 +320,10 @@ func Serve( ) error { acquisTomb = tomb.Tomb{} outputsTomb = tomb.Tomb{} - apiTomb = tomb.Tomb{} - crowdsecTomb = tomb.Tomb{} - pluginTomb = tomb.Tomb{} + apiCancel = nil + apiDone = nil + crowdCancel = nil + crowdDone = nil if cConfig.API.Server != nil && cConfig.API.Server.DbConfig != nil { dbCfg := cConfig.API.Server.DbConfig @@ -408,25 +417,22 @@ func Serve( return HandleSignals(ctx, cConfig) } - waitChans := make([]<-chan struct{}, 0) - if !cConfig.DisableAgent { - waitChans = append(waitChans, crowdsecTomb.Dead()) + if crowdDone != nil { + if err := <-crowdDone; err != nil { + return fmt.Errorf("crowdsec shutdown: %w", err) + } + } + log.Infof("crowdsec shutdown") } if !cConfig.DisableAPI { - waitChans = append(waitChans, apiTomb.Dead()) - } - - for _, ch := range waitChans { - <-ch - - switch ch { - case apiTomb.Dead(): - log.Infof("api shutdown") - case crowdsecTomb.Dead(): - log.Infof("crowdsec shutdown") + if apiDone != nil { + 
if err := <-apiDone; err != nil { + return fmt.Errorf("api shutdown: %w", err) + } } + log.Infof("api shutdown") } return nil diff --git a/pkg/acquisition/modules/appsec/run.go b/pkg/acquisition/modules/appsec/run.go index deb96fa714d..dc6ecb9008e 100644 --- a/pkg/acquisition/modules/appsec/run.go +++ b/pkg/acquisition/modules/appsec/run.go @@ -127,7 +127,14 @@ func (w *Source) StreamingAcquisition(ctx context.Context, out chan pipeline.Eve if err != nil { w.logger.Errorf("failed to fetch allowlists for appsec, disabling them: %s", err) } else { - w.appsecAllowlistClient.StartRefresh(ctx, t) + refreshCtx, refreshCancel := context.WithCancel(ctx) + t.Go(func() error { + <-t.Dying() + refreshCancel() + return nil + }) + + w.appsecAllowlistClient.StartRefresh(refreshCtx) } t.Go(func() error { diff --git a/pkg/acquisition/modules/file/file_test.go b/pkg/acquisition/modules/file/file_test.go index faf2d1888cc..a837f76e368 100644 --- a/pkg/acquisition/modules/file/file_test.go +++ b/pkg/acquisition/modules/file/file_test.go @@ -1,10 +1,12 @@ package fileacquisition_test import ( + "context" "fmt" "os" "path/filepath" "runtime" + "strings" "sync/atomic" "testing" "time" @@ -13,7 +15,6 @@ import ( "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/go-cs-lib/cstest" @@ -340,8 +341,9 @@ force_inotify: true`, testPattern), subLogger := logger.WithField("type", fileacquisition.ModuleName) - tomb := tomb.Tomb{} out := make(chan pipeline.Event) + streamCtx, cancelStream := context.WithCancel(ctx) + streamErr := make(chan error, 1) f := fileacquisition.Source{} @@ -378,8 +380,13 @@ force_inotify: true`, testPattern), }() } - err = f.StreamingAcquisition(ctx, out, &tomb) - cstest.RequireErrorContains(t, err, tc.expectedErr) + go func() { + streamErr <- f.Stream(streamCtx, out) + }() + + if tc.expectedLines == 0 { + time.Sleep(200 * time.Millisecond) + } if 
tc.expectedLines != 0 { // f.IsTailing is path delimiter sensitive @@ -422,11 +429,18 @@ force_inotify: true`, testPattern), } if tc.expectedOutput != "" { - if hook.LastEntry() == nil { + found := false + for _, entry := range hook.AllEntries() { + if strings.Contains(entry.Message, tc.expectedOutput) { + found = true + break + } + } + + if !found { t.Fatalf("expected output %s, but got nothing", tc.expectedOutput) } - assert.Contains(t, hook.LastEntry().Message, tc.expectedOutput) hook.Reset() } @@ -434,7 +448,9 @@ force_inotify: true`, testPattern), tc.teardown() } - tomb.Kill(nil) + cancelStream() + err = <-streamErr + cstest.RequireErrorContains(t, err, tc.expectedErr) }) } } @@ -481,11 +497,14 @@ mode: tail // Create channel for events eventChan := make(chan pipeline.Event) - tomb := tomb.Tomb{} + streamCtx, cancelStream := context.WithCancel(ctx) + defer cancelStream() + streamErr := make(chan error, 1) // Start acquisition - err = f.StreamingAcquisition(ctx, eventChan, &tomb) - require.NoError(t, err) + go func() { + streamErr <- f.Stream(streamCtx, eventChan) + }() // Create a test file testFile := filepath.Join(dir, "test.log") @@ -503,8 +522,8 @@ mode: tail require.False(t, f.IsTailing(ignoredFile), "File should be ignored after polling") // Cleanup - tomb.Kill(nil) - require.NoError(t, tomb.Wait()) + cancelStream() + require.NoError(t, <-streamErr) } func TestFileResurrectionViaPolling(t *testing.T) { @@ -532,10 +551,13 @@ mode: tail require.NoError(t, err) eventChan := make(chan pipeline.Event) - tomb := tomb.Tomb{} + streamCtx, cancelStream := context.WithCancel(ctx) + defer cancelStream() + streamErr := make(chan error, 1) - err = f.StreamingAcquisition(ctx, eventChan, &tomb) - require.NoError(t, err) + go func() { + streamErr <- f.Stream(streamCtx, eventChan) + }() // Wait for initial tail setup time.Sleep(100 * time.Millisecond) @@ -553,6 +575,6 @@ mode: tail require.True(t, isTailed, "File should be resurrected via polling") // Cleanup - 
tomb.Kill(nil) - require.NoError(t, tomb.Wait()) + cancelStream() + require.NoError(t, <-streamErr) } diff --git a/pkg/acquisition/modules/file/init.go b/pkg/acquisition/modules/file/init.go index 42c54eba77f..2c63e8f48b3 100644 --- a/pkg/acquisition/modules/file/init.go +++ b/pkg/acquisition/modules/file/init.go @@ -7,11 +7,11 @@ import ( var ( // verify interface compliance - _ types.DataSource = (*Source)(nil) - _ types.DSNConfigurer = (*Source)(nil) - _ types.BatchFetcher = (*Source)(nil) - _ types.Tailer = (*Source)(nil) - _ types.MetricsProvider = (*Source)(nil) + _ types.DataSource = (*Source)(nil) + _ types.DSNConfigurer = (*Source)(nil) + _ types.BatchFetcher = (*Source)(nil) + _ types.RestartableStreamer = (*Source)(nil) + _ types.MetricsProvider = (*Source)(nil) ) const ModuleName = "file" diff --git a/pkg/acquisition/modules/file/run.go b/pkg/acquisition/modules/file/run.go index 37434b8739c..76062de7908 100644 --- a/pkg/acquisition/modules/file/run.go +++ b/pkg/acquisition/modules/file/run.go @@ -5,6 +5,7 @@ import ( "cmp" "compress/gzip" "context" + "errors" "fmt" "io" "os" @@ -16,7 +17,6 @@ import ( "github.com/nxadm/tail" "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/go-cs-lib/trace" @@ -53,19 +53,46 @@ func (s *Source) OneShot(ctx context.Context, out chan pipeline.Event) error { return nil } -func (s *Source) StreamingAcquisition(_ context.Context, out chan pipeline.Event, t *tomb.Tomb) error { +func sendStreamErr(errCh chan<- error, err error) { + if err == nil { + return + } + + select { + case errCh <- err: + default: + } +} + +func (s *Source) Stream(ctx context.Context, out chan pipeline.Event) error { s.logger.Debug("Starting live acquisition") - t.Go(func() error { - return s.monitorNewFiles(out, t) - }) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + errCh := make(chan error, 1) + go func() { + sendStreamErr(errCh, s.monitorNewFiles(ctx, out, 
errCh)) + }() for _, file := range s.files { - if err := s.setupTailForFile(file, out, true, t); err != nil { + if err := s.setupTailForFile(ctx, file, out, true, errCh); err != nil { s.logger.Errorf("Error setting up tail for %s: %s", file, err) } } - return nil + for { + select { + case <-ctx.Done(): + return nil + case err := <-errCh: + if err == nil || errors.Is(err, context.Canceled) || ctx.Err() != nil { + return nil + } + + return err + } + } } // checkAndTailFile validates and sets up tailing for a given file. It performs the following checks: @@ -77,10 +104,10 @@ func (s *Source) StreamingAcquisition(_ context.Context, out chan pipeline.Event // - filename: The path to the file to check and potentially tail // - logger: A log.Entry for contextual logging // - out: Channel to send file events to -// - t: A tomb.Tomb for graceful shutdown handling +// - ctx: context used for graceful shutdown handling // // Returns an error if any validation fails or if tailing setup fails -func (s *Source) checkAndTailFile(filename string, logger *log.Entry, out chan pipeline.Event, t *tomb.Tomb) error { +func (s *Source) checkAndTailFile(ctx context.Context, filename string, logger *log.Entry, out chan pipeline.Event, errCh chan<- error) error { // Check if it's a directory fi, err := os.Stat(filename) if err != nil { @@ -117,7 +144,7 @@ func (s *Source) checkAndTailFile(filename string, logger *log.Entry, out chan p } // Setup the tail if needed - if err := s.setupTailForFile(filename, out, false, t); err != nil { + if err := s.setupTailForFile(ctx, filename, out, false, errCh); err != nil { logger.Errorf("Error setting up tail for file %s: %s", filename, err) return err } @@ -125,13 +152,13 @@ func (s *Source) checkAndTailFile(filename string, logger *log.Entry, out chan p return nil } -func (s *Source) monitorNewFiles(out chan pipeline.Event, t *tomb.Tomb) error { +func (s *Source) monitorNewFiles(ctx context.Context, out chan pipeline.Event, errCh chan<- error) error { 
logger := s.logger.WithField("goroutine", "inotify") // Setup polling if enabled var ( tickerChan <-chan time.Time - ticker *time.Ticker + ticker *time.Ticker ) if s.config.DiscoveryPollEnable { @@ -154,7 +181,7 @@ func (s *Source) monitorNewFiles(out chan pipeline.Event, t *tomb.Tomb) error { continue } - _ = s.checkAndTailFile(event.Name, logger, out, t) + _ = s.checkAndTailFile(ctx, event.Name, logger, out, errCh) case <-tickerChan: // Will never trigger if tickerChan is nil // Poll for all configured patterns @@ -166,7 +193,7 @@ func (s *Source) monitorNewFiles(out chan pipeline.Event, t *tomb.Tomb) error { } for _, file := range files { - _ = s.checkAndTailFile(file, logger, out, t) + _ = s.checkAndTailFile(ctx, file, logger, out, errCh) } } @@ -177,7 +204,7 @@ func (s *Source) monitorNewFiles(out chan pipeline.Event, t *tomb.Tomb) error { logger.Errorf("Error while monitoring folder: %s", err) - case <-t.Dying(): + case <-ctx.Done(): err := s.watcher.Close() if err != nil { return fmt.Errorf("could not remove all inotify watches: %w", err) @@ -188,7 +215,7 @@ func (s *Source) monitorNewFiles(out chan pipeline.Event, t *tomb.Tomb) error { } } -func (s *Source) setupTailForFile(file string, out chan pipeline.Event, seekEnd bool, t *tomb.Tomb) error { +func (s *Source) setupTailForFile(ctx context.Context, file string, out chan pipeline.Event, seekEnd bool, errCh chan<- error) error { logger := s.logger.WithField("file", file) if s.isExcluded(file) { @@ -283,21 +310,21 @@ func (s *Source) setupTailForFile(file string, out chan pipeline.Event, seekEnd s.tails[file] = true s.tailMapMutex.Unlock() - t.Go(func() error { + go func() { defer trace.ReportPanic() - return s.tailFile(out, t, tail) - }) + sendStreamErr(errCh, s.tailFile(ctx, out, tail)) + }() return nil } -func (s *Source) tailFile(out chan pipeline.Event, t *tomb.Tomb, tail *tail.Tail) error { +func (s *Source) tailFile(ctx context.Context, out chan pipeline.Event, tail *tail.Tail) error { logger := 
s.logger.WithField("tail", tail.Filename) logger.Debug("-> start tailing") for { select { - case <-t.Dying(): + case <-ctx.Done(): logger.Info("File datasource stopping") if err := tail.Stop(); err != nil { diff --git a/pkg/acquisition/modules/http/http_test.go b/pkg/acquisition/modules/http/http_test.go index 12f1c959e62..67c6617e39c 100644 --- a/pkg/acquisition/modules/http/http_test.go +++ b/pkg/acquisition/modules/http/http_test.go @@ -20,7 +20,6 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/crowdsec/pkg/metrics" "github.com/crowdsecurity/crowdsec/pkg/pipeline" @@ -56,7 +55,7 @@ func TestGetName(t *testing.T) { assert.Equal(t, "http", h.GetName()) } -func SetupAndRunHTTPSource(t *testing.T, h *Source, config []byte, metricLevel metrics.AcquisitionMetricsLevel) (chan pipeline.Event, *prometheus.Registry, *tomb.Tomb) { +func SetupAndRunHTTPSource(t *testing.T, h *Source, config []byte, metricLevel metrics.AcquisitionMetricsLevel) (chan pipeline.Event, *prometheus.Registry, context.CancelFunc, chan error) { ctx := t.Context() subLogger := log.WithFields(log.Fields{ "type": ModuleName, @@ -64,10 +63,12 @@ func SetupAndRunHTTPSource(t *testing.T, h *Source, config []byte, metricLevel m err := h.Configure(ctx, config, subLogger, metricLevel) require.NoError(t, err) - tomb := tomb.Tomb{} + streamCtx, cancel := context.WithCancel(ctx) + streamErr := make(chan error, 1) out := make(chan pipeline.Event) - err = h.StreamingAcquisition(ctx, out, &tomb) - require.NoError(t, err) + go func() { + streamErr <- h.Stream(streamCtx, out) + }() testRegistry := prometheus.NewPedanticRegistry() for _, metric := range h.GetMetrics() { @@ -75,12 +76,22 @@ func SetupAndRunHTTPSource(t *testing.T, h *Source, config []byte, metricLevel m require.NoError(t, err) } - return out, testRegistry, &tomb + return out, testRegistry, cancel, streamErr +} + +func stopHTTPSource(t 
*testing.T, h *Source, cancel context.CancelFunc, streamErr <-chan error) { + t.Helper() + + if h.Server != nil { + _ = h.Server.Close() + } + cancel() + require.NoError(t, <-streamErr) } func TestStreamingAcquisitionHTTPMethod(t *testing.T) { h := &Source{} - _, _, tomb := SetupAndRunHTTPSource(t, h, []byte(` + _, _, cancel, streamErr := SetupAndRunHTTPSource(t, h, []byte(` source: http listen_addr: 127.0.0.1:8080 path: /test @@ -122,17 +133,14 @@ basic_auth: assert.Equal(t, http.StatusOK, res.StatusCode) closeBody(t, res) - h.Server.Close() - tomb.Kill(nil) - err = tomb.Wait() - require.NoError(t, err) + stopHTTPSource(t, h, cancel, streamErr) } func TestStreamingAcquisitionUnknownPath(t *testing.T) { ctx := t.Context() h := &Source{} - _, _, tomb := SetupAndRunHTTPSource(t, h, []byte(` + _, _, cancel, streamErr := SetupAndRunHTTPSource(t, h, []byte(` source: http listen_addr: 127.0.0.1:8080 path: /test @@ -151,16 +159,13 @@ basic_auth: assert.Equal(t, http.StatusNotFound, res.StatusCode) closeBody(t, res) - h.Server.Close() - tomb.Kill(nil) - err = tomb.Wait() - require.NoError(t, err) + stopHTTPSource(t, h, cancel, streamErr) } func TestStreamingAcquisitionBasicAuth(t *testing.T) { ctx := t.Context() h := &Source{} - _, _, tomb := SetupAndRunHTTPSource(t, h, []byte(` + _, _, cancel, streamErr := SetupAndRunHTTPSource(t, h, []byte(` source: http listen_addr: 127.0.0.1:8080 path: /test @@ -191,16 +196,13 @@ basic_auth: assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) closeBody(t, resp) - h.Server.Close() - tomb.Kill(nil) - err = tomb.Wait() - require.NoError(t, err) + stopHTTPSource(t, h, cancel, streamErr) } func TestStreamingAcquisitionBadHeaders(t *testing.T) { ctx := t.Context() h := &Source{} - _, _, tomb := SetupAndRunHTTPSource(t, h, []byte(` + _, _, cancel, streamErr := SetupAndRunHTTPSource(t, h, []byte(` source: http listen_addr: 127.0.0.1:8080 path: /test @@ -221,16 +223,13 @@ headers: assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) 
closeBody(t, resp) - h.Server.Close() - tomb.Kill(nil) - err = tomb.Wait() - require.NoError(t, err) + stopHTTPSource(t, h, cancel, streamErr) } func TestStreamingAcquisitionMaxBodySize(t *testing.T) { ctx := t.Context() h := &Source{} - _, _, tomb := SetupAndRunHTTPSource(t, h, []byte(` + _, _, cancel, streamErr := SetupAndRunHTTPSource(t, h, []byte(` source: http listen_addr: 127.0.0.1:8080 path: /test @@ -252,16 +251,13 @@ max_body_size: 5`), 0) assert.Equal(t, http.StatusRequestEntityTooLarge, resp.StatusCode) closeBody(t, resp) - h.Server.Close() - tomb.Kill(nil) - err = tomb.Wait() - require.NoError(t, err) + stopHTTPSource(t, h, cancel, streamErr) } func TestStreamingAcquisitionSuccess(t *testing.T) { ctx := t.Context() h := &Source{} - out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(` + out, reg, cancel, streamErr := SetupAndRunHTTPSource(t, h, []byte(` source: http listen_addr: 127.0.0.1:8080 path: /test @@ -291,16 +287,13 @@ headers: assertMetrics(t, reg, h.GetMetrics(), 1) - h.Server.Close() - tomb.Kill(nil) - err = tomb.Wait() - require.NoError(t, err) + stopHTTPSource(t, h, cancel, streamErr) } func TestStreamingAcquisitionCustomStatusCodeAndCustomHeaders(t *testing.T) { ctx := t.Context() h := &Source{} - out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(` + out, reg, cancel, streamErr := SetupAndRunHTTPSource(t, h, []byte(` source: http listen_addr: 127.0.0.1:8080 path: /test @@ -334,10 +327,7 @@ custom_headers: assertMetrics(t, reg, h.GetMetrics(), 1) - h.Server.Close() - tomb.Kill(nil) - err = tomb.Wait() - require.NoError(t, err) + stopHTTPSource(t, h, cancel, streamErr) } func TestAcquistionSocket(t *testing.T) { @@ -346,7 +336,7 @@ func TestAcquistionSocket(t *testing.T) { ctx := t.Context() h := &Source{} - out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(` + out, reg, cancel, streamErr := SetupAndRunHTTPSource(t, h, []byte(` source: http listen_socket: `+socketFile+` path: /test @@ -382,10 +372,7 @@ headers: assertMetrics(t, reg, 
h.GetMetrics(), 1) - h.Server.Close() - tomb.Kill(nil) - err = tomb.Wait() - require.NoError(t, err) + stopHTTPSource(t, h, cancel, streamErr) } type slowReader struct { @@ -447,7 +434,7 @@ func assertEvents(out chan pipeline.Event, expected []string, errChan chan error func TestStreamingAcquisitionTimeout(t *testing.T) { ctx := t.Context() h := &Source{} - _, _, tomb := SetupAndRunHTTPSource(t, h, []byte(` + _, _, cancel, streamErr := SetupAndRunHTTPSource(t, h, []byte(` source: http listen_addr: 127.0.0.1:8080 path: /test @@ -476,17 +463,14 @@ timeout: 1s`), 0) assert.Equal(t, http.StatusBadRequest, resp.StatusCode) closeBody(t, resp) - h.Server.Close() - tomb.Kill(nil) - err = tomb.Wait() - require.NoError(t, err) + stopHTTPSource(t, h, cancel, streamErr) } func TestStreamingAcquisitionTLSHTTPRequest(t *testing.T) { ctx := t.Context() h := &Source{} - _, _, tomb := SetupAndRunHTTPSource(t, h, []byte(` + _, _, cancel, streamErr := SetupAndRunHTTPSource(t, h, []byte(` source: http listen_addr: 127.0.0.1:8080 auth_type: mtls @@ -508,16 +492,13 @@ tls: assert.Equal(t, http.StatusBadRequest, resp.StatusCode) closeBody(t, resp) - h.Server.Close() - tomb.Kill(nil) - err = tomb.Wait() - require.NoError(t, err) + stopHTTPSource(t, h, cancel, streamErr) } func TestStreamingAcquisitionTLSWithHeadersAuthSuccess(t *testing.T) { ctx := t.Context() h := &Source{} - out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(` + out, reg, cancel, streamErr := SetupAndRunHTTPSource(t, h, []byte(` source: http listen_addr: 127.0.0.1:8080 path: /test @@ -567,16 +548,13 @@ tls: assertMetrics(t, reg, h.GetMetrics(), 0) - h.Server.Close() - tomb.Kill(nil) - err = tomb.Wait() - require.NoError(t, err) + stopHTTPSource(t, h, cancel, streamErr) } func TestStreamingAcquisitionMTLS(t *testing.T) { ctx := t.Context() h := &Source{} - out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(` + out, reg, cancel, streamErr := SetupAndRunHTTPSource(t, h, []byte(` source: http listen_addr: 127.0.0.1:8080 
path: /test @@ -628,16 +606,13 @@ tls: assertMetrics(t, reg, h.GetMetrics(), 0) - h.Server.Close() - tomb.Kill(nil) - err = tomb.Wait() - require.NoError(t, err) + stopHTTPSource(t, h, cancel, streamErr) } func TestStreamingAcquisitionGzipData(t *testing.T) { ctx := t.Context() h := &Source{} - out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(` + out, reg, cancel, streamErr := SetupAndRunHTTPSource(t, h, []byte(` source: http listen_addr: 127.0.0.1:8080 path: /test @@ -684,16 +659,13 @@ headers: assertMetrics(t, reg, h.GetMetrics(), 2) - h.Server.Close() - tomb.Kill(nil) - err = tomb.Wait() - require.NoError(t, err) + stopHTTPSource(t, h, cancel, streamErr) } func TestStreamingAcquisitionNDJson(t *testing.T) { ctx := t.Context() h := &Source{} - out, reg, tomb := SetupAndRunHTTPSource(t, h, []byte(` + out, reg, cancel, streamErr := SetupAndRunHTTPSource(t, h, []byte(` source: http listen_addr: 127.0.0.1:8080 path: /test @@ -726,10 +698,7 @@ headers: assertMetrics(t, reg, h.GetMetrics(), 2) - h.Server.Close() - tomb.Kill(nil) - err = tomb.Wait() - require.NoError(t, err) + stopHTTPSource(t, h, cancel, streamErr) } func assertMetrics(t *testing.T, reg *prometheus.Registry, metrics []prometheus.Collector, expected int) { diff --git a/pkg/acquisition/modules/http/init.go b/pkg/acquisition/modules/http/init.go index 07d19e48007..17191d0b25d 100644 --- a/pkg/acquisition/modules/http/init.go +++ b/pkg/acquisition/modules/http/init.go @@ -7,9 +7,9 @@ import ( var ( // verify interface compliance - _ types.DataSource = (*Source)(nil) - _ types.Tailer = (*Source)(nil) - _ types.MetricsProvider = (*Source)(nil) + _ types.DataSource = (*Source)(nil) + _ types.RestartableStreamer = (*Source)(nil) + _ types.MetricsProvider = (*Source)(nil) ) const ModuleName = "http" diff --git a/pkg/acquisition/modules/http/run.go b/pkg/acquisition/modules/http/run.go index ce548917333..1a2cee12ea2 100644 --- a/pkg/acquisition/modules/http/run.go +++ b/pkg/acquisition/modules/http/run.go @@ 
-15,7 +15,6 @@ import ( "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/go-cs-lib/trace" @@ -130,59 +129,63 @@ func (s *Source) processRequest(w http.ResponseWriter, r *http.Request, hc *Conf return nil } -func (s *Source) RunServer(ctx context.Context, out chan pipeline.Event, t *tomb.Tomb) error { - mux := http.NewServeMux() - mux.HandleFunc(s.Config.Path, func(w http.ResponseWriter, r *http.Request) { - if err := authorizeRequest(r, &s.Config); err != nil { - s.logger.Errorf("failed to authorize request from '%s': %s", r.RemoteAddr, err) - http.Error(w, "Unauthorized", http.StatusUnauthorized) - - return - } +func (s *Source) handleStreamRequest(w http.ResponseWriter, r *http.Request, out chan pipeline.Event) { + if err := authorizeRequest(r, &s.Config); err != nil { + s.logger.Errorf("failed to authorize request from '%s': %s", r.RemoteAddr, err) + http.Error(w, "Unauthorized", http.StatusUnauthorized) - switch r.Method { - case http.MethodGet, http.MethodHead: // Return a 200 if the auth was successful - s.logger.Infof("successful %s request from '%s'", r.Method, r.RemoteAddr) - w.WriteHeader(http.StatusOK) + return + } - if _, err := w.Write([]byte("OK")); err != nil { - s.logger.Errorf("failed to write response: %v", err) - } + switch r.Method { + case http.MethodGet, http.MethodHead: // Return a 200 if the auth was successful + s.logger.Infof("successful %s request from '%s'", r.Method, r.RemoteAddr) + w.WriteHeader(http.StatusOK) - return - case http.MethodPost: // POST is handled below - default: - s.logger.Errorf("method not allowed: %s", r.Method) - http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) - return + if _, err := w.Write([]byte("OK")); err != nil { + s.logger.Errorf("failed to write response: %v", err) } - if r.RemoteAddr == "@" { - // We check if request came from unix socket and if so we set to loopback - r.RemoteAddr = "127.0.0.1:65535" - } + 
return + case http.MethodPost: // POST is handled below + default: + s.logger.Errorf("method not allowed: %s", r.Method) + http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) + return + } - err := s.processRequest(w, r, &s.Config, out) - if err != nil { - s.logger.Errorf("failed to process request from '%s': %s", r.RemoteAddr, err) - return - } + if r.RemoteAddr == "@" { + // We check if request came from unix socket and if so we set to loopback + r.RemoteAddr = "127.0.0.1:65535" + } - if s.Config.CustomHeaders != nil { - for key, value := range s.Config.CustomHeaders { - w.Header().Set(key, value) - } - } + err := s.processRequest(w, r, &s.Config, out) + if err != nil { + s.logger.Errorf("failed to process request from '%s': %s", r.RemoteAddr, err) + return + } - if s.Config.CustomStatusCode != nil { - w.WriteHeader(*s.Config.CustomStatusCode) - } else { - w.WriteHeader(http.StatusOK) + if s.Config.CustomHeaders != nil { + for key, value := range s.Config.CustomHeaders { + w.Header().Set(key, value) } + } - if _, err := w.Write([]byte("OK")); err != nil { - s.logger.Errorf("failed to write response: %v", err) - } + if s.Config.CustomStatusCode != nil { + w.WriteHeader(*s.Config.CustomStatusCode) + } else { + w.WriteHeader(http.StatusOK) + } + + if _, err := w.Write([]byte("OK")); err != nil { + s.logger.Errorf("failed to write response: %v", err) + } +} + +func (s *Source) Stream(ctx context.Context, out chan pipeline.Event) error { + mux := http.NewServeMux() + mux.HandleFunc(s.Config.Path, func(w http.ResponseWriter, r *http.Request) { + s.handleStreamRequest(w, r, out) }) s.Server = &http.Server{ @@ -210,12 +213,13 @@ func (s *Source) RunServer(ctx context.Context, out chan pipeline.Event, t *tomb } listenConfig := &net.ListenConfig{} + serverErr := make(chan error, 2) - t.Go(func() error { + go func() { defer trace.ReportPanic() if s.Config.ListenSocket == "" { - return nil + return } s.logger.Infof("creating unix socket on %s", 
s.Config.ListenSocket) @@ -223,29 +227,38 @@ func (s *Source) RunServer(ctx context.Context, out chan pipeline.Event, t *tomb listener, err := listenConfig.Listen(ctx, "unix", s.Config.ListenSocket) if err != nil { - return csnet.WrapSockErr(err, s.Config.ListenSocket) + select { + case serverErr <- csnet.WrapSockErr(err, s.Config.ListenSocket): + default: + } + + return } if s.Config.TLS != nil { err := s.Server.ServeTLS(listener, s.Config.TLS.ServerCert, s.Config.TLS.ServerKey) if err != nil && err != http.ErrServerClosed { - return fmt.Errorf("https server failed: %w", err) + select { + case serverErr <- fmt.Errorf("https server failed: %w", err): + default: + } } } else { err := s.Server.Serve(listener) if err != nil && err != http.ErrServerClosed { - return fmt.Errorf("http server failed: %w", err) + select { + case serverErr <- fmt.Errorf("http server failed: %w", err): + default: + } } } + }() - return nil - }) - - t.Go(func() error { + go func() { defer trace.ReportPanic() if s.Config.ListenAddr == "" { - return nil + return } if s.Config.TLS != nil { @@ -253,38 +266,33 @@ func (s *Source) RunServer(ctx context.Context, out chan pipeline.Event, t *tomb err := s.Server.ListenAndServeTLS(s.Config.TLS.ServerCert, s.Config.TLS.ServerKey) if err != nil && err != http.ErrServerClosed { - return fmt.Errorf("https server failed: %w", err) + select { + case serverErr <- fmt.Errorf("https server failed: %w", err): + default: + } } } else { s.logger.Infof("start http server on %s", s.Config.ListenAddr) err := s.Server.ListenAndServe() if err != nil && err != http.ErrServerClosed { - return fmt.Errorf("http server failed: %w", err) + select { + case serverErr <- fmt.Errorf("http server failed: %w", err): + default: + } } } + }() - return nil - }) - - <-t.Dying() - - s.logger.Infof("%s datasource stopping", s.GetName()) + select { + case <-ctx.Done(): + s.logger.Infof("%s datasource stopping", s.GetName()) + if err := s.Server.Close(); err != nil && !errors.Is(err, 
http.ErrServerClosed) { + return fmt.Errorf("while closing %s server: %w", s.GetName(), err) + } - if err := s.Server.Close(); err != nil { - return fmt.Errorf("while closing %s server: %w", s.GetName(), err) + return nil + case err := <-serverErr: + return err } - - return nil -} - -func (s *Source) StreamingAcquisition(ctx context.Context, out chan pipeline.Event, t *tomb.Tomb) error { - s.logger.Debugf("start http server on %s", s.Config.ListenAddr) - - t.Go(func() error { - defer trace.ReportPanic() - return s.RunServer(ctx, out, t) - }) - - return nil } diff --git a/pkg/acquisition/modules/kafka/init.go b/pkg/acquisition/modules/kafka/init.go index d5f75a0a920..bd29572da27 100644 --- a/pkg/acquisition/modules/kafka/init.go +++ b/pkg/acquisition/modules/kafka/init.go @@ -7,9 +7,9 @@ import ( var ( // verify interface compliance - _ types.DataSource = (*Source)(nil) - _ types.Tailer = (*Source)(nil) - _ types.MetricsProvider = (*Source)(nil) + _ types.DataSource = (*Source)(nil) + _ types.RestartableStreamer = (*Source)(nil) + _ types.MetricsProvider = (*Source)(nil) ) const ModuleName = "kafka" diff --git a/pkg/acquisition/modules/kafka/kafka_test.go b/pkg/acquisition/modules/kafka/kafka_test.go index 7072e95daab..7a6233289fe 100644 --- a/pkg/acquisition/modules/kafka/kafka_test.go +++ b/pkg/acquisition/modules/kafka/kafka_test.go @@ -10,7 +10,6 @@ import ( "github.com/segmentio/kafka-go" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/go-cs-lib/cstest" @@ -68,8 +67,6 @@ func createTopic(topic string, broker string) { func TestStreamingAcquisition(t *testing.T) { cstest.SetAWSTestEnv(t) - ctx := t.Context() - tests := []struct { name string logs []string @@ -101,6 +98,9 @@ func TestStreamingAcquisition(t *testing.T) { for _, ts := range tests { t.Run(ts.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + k := Source{} err := k.Configure(ctx, []byte(` 
@@ -112,11 +112,12 @@ topic: crowdsecplaintext`), subLogger, metrics.AcquisitionMetricsLevelNone) t.Fatalf("could not configure kafka source : %s", err) } - tomb := tomb.Tomb{} - out := make(chan pipeline.Event) - err = k.StreamingAcquisition(ctx, out, &tomb) - cstest.AssertErrorContains(t, err, ts.expectedErr) + + streamErr := make(chan error, 1) + go func() { + streamErr <- k.Stream(ctx, out) + }() actualLines := 0 @@ -132,9 +133,9 @@ topic: crowdsecplaintext`), subLogger, metrics.AcquisitionMetricsLevelNone) } require.Equal(t, ts.expectedLines, actualLines) - tomb.Kill(nil) - err = tomb.Wait() - require.NoError(t, err) + cancel() + err = <-streamErr + cstest.AssertErrorContains(t, err, ts.expectedErr) }) } } @@ -142,8 +143,6 @@ topic: crowdsecplaintext`), subLogger, metrics.AcquisitionMetricsLevelNone) func TestStreamingAcquisitionWithSSL(t *testing.T) { cstest.SetAWSTestEnv(t) - ctx := t.Context() - tests := []struct { name string logs []string @@ -174,6 +173,9 @@ func TestStreamingAcquisitionWithSSL(t *testing.T) { for _, ts := range tests { t.Run(ts.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + k := Source{} err := k.Configure(ctx, []byte(` @@ -191,10 +193,12 @@ tls: t.Fatalf("could not configure kafka source : %s", err) } - tomb := tomb.Tomb{} out := make(chan pipeline.Event) - err = k.StreamingAcquisition(ctx, out, &tomb) - cstest.AssertErrorContains(t, err, ts.expectedErr) + + streamErr := make(chan error, 1) + go func() { + streamErr <- k.Stream(ctx, out) + }() actualLines := 0 @@ -210,9 +214,9 @@ tls: } require.Equal(t, ts.expectedLines, actualLines) - tomb.Kill(nil) - err = tomb.Wait() - require.NoError(t, err) + cancel() + err = <-streamErr + cstest.AssertErrorContains(t, err, ts.expectedErr) }) } } diff --git a/pkg/acquisition/modules/kafka/run.go b/pkg/acquisition/modules/kafka/run.go index 8b69605016f..5c597dea2a4 100644 --- a/pkg/acquisition/modules/kafka/run.go +++ 
b/pkg/acquisition/modules/kafka/run.go @@ -8,9 +8,6 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/segmentio/kafka-go" - "gopkg.in/tomb.v2" - - "github.com/crowdsecurity/go-cs-lib/trace" "github.com/crowdsecurity/crowdsec/pkg/metrics" "github.com/crowdsecurity/crowdsec/pkg/pipeline" @@ -33,6 +30,10 @@ func (s *Source) ReadMessage(ctx context.Context, out chan pipeline.Event) error return nil } + if ctx.Err() != nil { + return nil //nolint:nilerr + } + s.logger.Errorln(fmt.Errorf("while reading %s message: %w", s.GetName(), err)) continue @@ -60,30 +61,16 @@ func (s *Source) ReadMessage(ctx context.Context, out chan pipeline.Event) error } } -func (s *Source) RunReader(ctx context.Context, out chan pipeline.Event, t *tomb.Tomb) error { - s.logger.Debugf("starting %s datasource reader goroutine with configuration %+v", s.GetName(), s.Config) - t.Go(func() error { - return s.ReadMessage(ctx, out) - }) - - <-t.Dying() - - s.logger.Infof("%s datasource topic %s stopping", s.GetName(), s.Config.Topic) - - if err := s.Reader.Close(); err != nil { - return fmt.Errorf("while closing %s reader on topic '%s': %w", s.GetName(), s.Config.Topic, err) - } - - return nil -} - -func (s *Source) StreamingAcquisition(ctx context.Context, out chan pipeline.Event, t *tomb.Tomb) error { +func (s *Source) Stream(ctx context.Context, out chan pipeline.Event) error { s.logger.Infof("start reader on brokers '%+v' with topic '%s'", s.Config.Brokers, s.Config.Topic) - t.Go(func() error { - defer trace.ReportPanic() - return s.RunReader(ctx, out, t) - }) + defer func() { + s.logger.Infof("%s datasource topic %s stopping", s.GetName(), s.Config.Topic) + + if err := s.Reader.Close(); err != nil { + s.logger.Errorf("while closing %s reader on topic '%s': %s", s.GetName(), s.Config.Topic, err) + } + }() - return nil + return s.ReadMessage(ctx, out) } diff --git a/pkg/acquisition/modules/kubernetesaudit/init.go b/pkg/acquisition/modules/kubernetesaudit/init.go index 
cf2a149c391..bae7fc6ae8c 100644 --- a/pkg/acquisition/modules/kubernetesaudit/init.go +++ b/pkg/acquisition/modules/kubernetesaudit/init.go @@ -7,9 +7,9 @@ import ( var ( // verify interface compliance - _ types.DataSource = (*Source)(nil) - _ types.Tailer = (*Source)(nil) - _ types.MetricsProvider = (*Source)(nil) + _ types.DataSource = (*Source)(nil) + _ types.RestartableStreamer = (*Source)(nil) + _ types.MetricsProvider = (*Source)(nil) ) const ModuleName = "k8s-audit" diff --git a/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go b/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go index df5b2be035c..eb581e7b91f 100644 --- a/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go +++ b/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go @@ -1,6 +1,7 @@ package kubernetesauditacquisition import ( + "context" "fmt" "net/http/httptest" "strings" @@ -10,7 +11,6 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/go-cs-lib/cstest" @@ -19,7 +19,6 @@ import ( ) func TestInvalidConfig(t *testing.T) { - ctx := t.Context() tests := []struct { name string config string @@ -39,8 +38,10 @@ webhook_path: /k8s-audit`, for _, test := range tests { t.Run(test.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + out := make(chan pipeline.Event) - tb := &tomb.Tomb{} f := Source{} @@ -51,19 +52,21 @@ webhook_path: /k8s-audit`, err = f.Configure(ctx, []byte(test.config), subLogger, metrics.AcquisitionMetricsLevelNone) require.NoError(t, err) - err = f.StreamingAcquisition(ctx, out, tb) - require.NoError(t, err) + + streamErr := make(chan error, 1) + go func() { + streamErr <- f.Stream(ctx, out) + }() time.Sleep(1 * time.Second) - tb.Kill(nil) - err = tb.Wait() + cancel() + err = <-streamErr cstest.RequireErrorContains(t, err, test.expectedErr) }) } } func TestHandler(t *testing.T) { - ctx := t.Context() 
tests := []struct { name string expectedStatusCode int @@ -185,20 +188,24 @@ func TestHandler(t *testing.T) { for idx, test := range tests { t.Run(test.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + out := make(chan pipeline.Event) - tb := &tomb.Tomb{} eventCount := 0 - tb.Go(func() error { + doneCounting := make(chan struct{}) + go func() { + defer close(doneCounting) for { select { case <-out: eventCount++ - case <-tb.Dying(): - return nil + case <-ctx.Done(): + return } } - }) + }() f := Source{} @@ -217,19 +224,21 @@ webhook_path: /k8s-audit`, port) req := httptest.NewRequest(test.method, "/k8s-audit", strings.NewReader(test.body)) w := httptest.NewRecorder() - err = f.StreamingAcquisition(ctx, out, tb) - require.NoError(t, err) + streamErr := make(chan error, 1) + go func() { + streamErr <- f.Stream(ctx, out) + }() f.webhookHandler(w, req) res := w.Result() assert.Equal(t, test.expectedStatusCode, res.StatusCode) - // time.Sleep(1 * time.Second) require.NoError(t, err) - tb.Kill(nil) - err = tb.Wait() + cancel() + <-doneCounting + err = <-streamErr require.NoError(t, err) assert.Equal(t, test.eventCount, eventCount) diff --git a/pkg/acquisition/modules/kubernetesaudit/run.go b/pkg/acquisition/modules/kubernetesaudit/run.go index e53b87f6fff..e711b3bf152 100644 --- a/pkg/acquisition/modules/kubernetesaudit/run.go +++ b/pkg/acquisition/modules/kubernetesaudit/run.go @@ -3,48 +3,45 @@ package kubernetesauditacquisition import ( "context" "encoding/json" + "errors" "fmt" "io" "net/http" "strings" "github.com/prometheus/client_golang/prometheus" - "gopkg.in/tomb.v2" "k8s.io/apiserver/pkg/apis/audit" - "github.com/crowdsecurity/go-cs-lib/trace" - "github.com/crowdsecurity/crowdsec/pkg/metrics" "github.com/crowdsecurity/crowdsec/pkg/pipeline" ) -func (s *Source) StreamingAcquisition(ctx context.Context, out chan pipeline.Event, t *tomb.Tomb) error { +func (s *Source) Stream(ctx context.Context, out chan pipeline.Event) 
error { s.outChan = out - t.Go(func() error { - defer trace.ReportPanic() + s.logger.Infof("Starting k8s-audit server on %s:%d%s", s.config.ListenAddr, s.config.ListenPort, s.config.WebhookPath) - s.logger.Infof("Starting k8s-audit server on %s:%d%s", s.config.ListenAddr, s.config.ListenPort, s.config.WebhookPath) + serverErr := make(chan error, 1) - t.Go(func() error { - err := s.server.ListenAndServe() - if err != nil && err != http.ErrServerClosed { - return fmt.Errorf("k8s-audit server failed: %w", err) - } + go func() { + err := s.server.ListenAndServe() + if err != nil && !errors.Is(err, http.ErrServerClosed) { + serverErr <- fmt.Errorf("k8s-audit server failed: %w", err) + } + }() - return nil - }) - <-t.Dying() + select { + case <-ctx.Done(): s.logger.Infof("Stopping k8s-audit server on %s:%d%s", s.config.ListenAddr, s.config.ListenPort, s.config.WebhookPath) - if err := s.server.Shutdown(ctx); err != nil { + if err := s.server.Shutdown(context.Background()); err != nil { //nolint:contextcheck // shutdown needs a fresh context after parent cancellation s.logger.Errorf("Error shutting down k8s-audit server: %s", err.Error()) } return nil - }) - - return nil + case err := <-serverErr: + return err + } } func (s *Source) webhookHandler(w http.ResponseWriter, r *http.Request) { diff --git a/pkg/acquisition/modules/loki/init.go b/pkg/acquisition/modules/loki/init.go index 2f77d1d9f49..18f94c44d3d 100644 --- a/pkg/acquisition/modules/loki/init.go +++ b/pkg/acquisition/modules/loki/init.go @@ -7,11 +7,11 @@ import ( var ( // verify interface compliance - _ types.DataSource = (*Source)(nil) - _ types.DSNConfigurer = (*Source)(nil) - _ types.Fetcher = (*Source)(nil) - _ types.Tailer = (*Source)(nil) - _ types.MetricsProvider = (*Source)(nil) + _ types.DataSource = (*Source)(nil) + _ types.DSNConfigurer = (*Source)(nil) + _ types.BatchFetcher = (*Source)(nil) + _ types.RestartableStreamer = (*Source)(nil) + _ types.MetricsProvider = (*Source)(nil) ) const ModuleName 
= "loki" diff --git a/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go b/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go index 83486c0c689..cf80988ff72 100644 --- a/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go +++ b/pkg/acquisition/modules/loki/internal/lokiclient/loki_client.go @@ -14,7 +14,6 @@ import ( "github.com/gorilla/websocket" log "github.com/sirupsen/logrus" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/crowdsec/pkg/apiclient/useragent" "maps" @@ -24,7 +23,6 @@ type LokiClient struct { Logger *log.Entry config Config - t *tomb.Tomb failStart time.Time currentTickerInterval time.Duration requestHeaders map[string]string @@ -69,10 +67,6 @@ func updateURI(uri string, lq LokiQueryRangeResponse, infinite bool) (string, er return u.String(), nil } -func (lc *LokiClient) SetTomb(t *tomb.Tomb) { - lc.t = t -} - func (lc *LokiClient) resetFailStart() { if !lc.failStart.IsZero() { log.Infof("loki is back after %s", time.Since(lc.failStart)) @@ -120,8 +114,6 @@ func (lc *LokiClient) queryRange(ctx context.Context, uri string, c chan *LokiQu select { case <-ctx.Done(): return ctx.Err() - case <-lc.t.Dying(): - return lc.t.Err() case <-ticker.C: resp, err := lc.Get(ctx, uri) if err != nil { @@ -218,9 +210,6 @@ func (lc *LokiClient) Ready(ctx context.Context) error { case <-ctx.Done(): tick.Stop() return ctx.Err() - case <-lc.t.Dying(): - tick.Stop() - return lc.t.Err() case <-tick.C: lc.Logger.Debug("Checking if Loki is ready") resp, err := lc.Get(ctx, url) @@ -273,7 +262,7 @@ func (lc *LokiClient) Tail(ctx context.Context) (chan *LokiResponse, error) { return responseChan, errors.New("error connecting to websocket") } - lc.t.Go(func() error { + go func() { defer conn.Close() for { jsonResponse := &LokiResponse{} @@ -281,12 +270,13 @@ func (lc *LokiClient) Tail(ctx context.Context) (chan *LokiResponse, error) { err = conn.ReadJSON(jsonResponse) if err != nil { lc.Logger.Errorf("Error reading from websocket: %s", err) 
- return fmt.Errorf("websocket error: %w", err) + close(responseChan) + return } responseChan <- jsonResponse } - }) + }() return responseChan, nil } @@ -305,9 +295,11 @@ func (lc *LokiClient) QueryRange(ctx context.Context, infinite bool) chan *LokiQ lc.Logger.Debugf("Since: %s (%s)", lc.config.Since, time.Now().Add(-lc.config.Since)) lc.Logger.Infof("Connecting to %s", url) - lc.t.Go(func() error { - return lc.queryRange(ctx, url, c, infinite) - }) + go func() { + if err := lc.queryRange(ctx, url, c, infinite); err != nil { + lc.Logger.Errorf("Error querying range: %s", err) + } + }() return c } diff --git a/pkg/acquisition/modules/loki/loki_test.go b/pkg/acquisition/modules/loki/loki_test.go index d79ad477cb7..e2b58c7f212 100644 --- a/pkg/acquisition/modules/loki/loki_test.go +++ b/pkg/acquisition/modules/loki/loki_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -15,7 +16,6 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" - tomb "gopkg.in/tomb.v2" "github.com/crowdsecurity/go-cs-lib/cstest" @@ -242,9 +242,7 @@ since: 1h } }() - lokiTomb := tomb.Tomb{} - - if err := lokiSource.OneShotAcquisition(ctx, out, &lokiTomb); err != nil { + if err := lokiSource.OneShot(ctx, out); err != nil { t.Fatalf("Unexpected error : %s", err) } @@ -307,7 +305,6 @@ query: > }) out := make(chan pipeline.Event) - lokiTomb := tomb.Tomb{} lokiSource := loki.Source{} err := lokiSource.Configure(ctx, []byte(ts.config), subLogger, metrics.AcquisitionMetricsLevelNone) @@ -315,46 +312,56 @@ query: > t.Fatalf("Unexpected error : %s", err) } - err = lokiSource.StreamingAcquisition(ctx, out, &lokiTomb) - cstest.AssertErrorContains(t, err, ts.streamErr) + streamCtx, streamCancel := context.WithCancel(ctx) + defer streamCancel() + + streamErrCh := make(chan error, 1) + go func() { + streamErrCh <- lokiSource.Stream(streamCtx, out) + }() if ts.streamErr != "" { + err = <-streamErrCh + 
cstest.AssertErrorContains(t, err, ts.streamErr) return } time.Sleep(time.Second * 2) // We need to give time to start reading from the WS - readTomb := tomb.Tomb{} readCtx, cancel := context.WithTimeout(ctx, time.Second*10) count := 0 + readErrCh := make(chan error, 1) - readTomb.Go(func() error { + go func() { defer cancel() for { select { case <-readCtx.Done(): - return readCtx.Err() + readErrCh <- readCtx.Err() + return case evt := <-out: count++ if !strings.HasSuffix(evt.Line.Raw, title) { - return fmt.Errorf("Incorrect suffix : %s", evt.Line.Raw) + readErrCh <- fmt.Errorf("Incorrect suffix : %s", evt.Line.Raw) + return } if count == ts.expectedLines { - return nil + readErrCh <- nil + return } } } - }) + }() err = feedLoki(ctx, subLogger, ts.expectedLines, title) if err != nil { t.Fatalf("Unexpected error : %s", err) } - err = readTomb.Wait() + err = <-readErrCh cancel() @@ -393,12 +400,12 @@ query: > out := make(chan pipeline.Event) - lokiTomb := &tomb.Tomb{} + streamCtx, streamCancel := context.WithCancel(ctx) + streamErrCh := make(chan error, 1) - err = lokiSource.StreamingAcquisition(ctx, out, lokiTomb) - if err != nil { - t.Fatalf("Unexpected error : %s", err) - } + go func() { + streamErrCh <- lokiSource.Stream(streamCtx, out) + }() time.Sleep(time.Second * 2) @@ -407,10 +414,10 @@ query: > t.Fatalf("Unexpected error : %s", err) } - lokiTomb.Kill(nil) + streamCancel() - err = lokiTomb.Wait() - if err != nil { + err = <-streamErrCh + if err != nil && !errors.Is(err, context.Canceled) { t.Fatalf("Unexpected error : %s", err) } } diff --git a/pkg/acquisition/modules/loki/run.go b/pkg/acquisition/modules/loki/run.go index 28bc8b94822..a3adb3ca693 100644 --- a/pkg/acquisition/modules/loki/run.go +++ b/pkg/acquisition/modules/loki/run.go @@ -10,17 +10,14 @@ import ( "fmt" "github.com/prometheus/client_golang/prometheus" - tomb "gopkg.in/tomb.v2" "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/loki/internal/lokiclient" 
"github.com/crowdsecurity/crowdsec/pkg/metrics" "github.com/crowdsecurity/crowdsec/pkg/pipeline" ) -// OneShotAcquisition reads a set of file and returns when done -func (l *Source) OneShotAcquisition(ctx context.Context, out chan pipeline.Event, t *tomb.Tomb) error { +func (l *Source) OneShot(ctx context.Context, out chan pipeline.Event) error { l.logger.Debug("Loki one shot acquisition") - l.Client.SetTomb(t) if !l.Config.NoReadyCheck { readyCtx, readyCancel := context.WithTimeout(ctx, l.Config.WaitForReady) @@ -38,7 +35,7 @@ func (l *Source) OneShotAcquisition(ctx context.Context, out chan pipeline.Event for { select { - case <-t.Dying(): + case <-ctx.Done(): l.logger.Debug("Loki one shot acquisition stopped") return nil case resp, ok := <-c: @@ -75,9 +72,7 @@ func (l *Source) readOneEntry(entry lokiclient.Entry, labels map[string]string, out <- evt } -func (l *Source) StreamingAcquisition(ctx context.Context, out chan pipeline.Event, t *tomb.Tomb) error { - l.Client.SetTomb(t) - +func (l *Source) Stream(ctx context.Context, out chan pipeline.Event) error { if !l.Config.NoReadyCheck { readyCtx, readyCancel := context.WithTimeout(ctx, l.Config.WaitForReady) defer readyCancel() @@ -89,30 +84,23 @@ func (l *Source) StreamingAcquisition(ctx context.Context, out chan pipeline.Eve ll := l.logger.WithField("websocket_url", l.lokiWebsocket) - t.Go(func() error { - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - respChan := l.Client.QueryRange(ctx, true) + respChan := l.Client.QueryRange(ctx, true) - for { - select { - case resp, ok := <-respChan: - if !ok { - ll.Warnf("loki channel closed") - return errors.New("loki channel closed") - } + for { + select { + case resp, ok := <-respChan: + if !ok { + ll.Warnf("loki channel closed") + return errors.New("loki channel closed") + } - for _, stream := range resp.Data.Result { - for _, entry := range stream.Entries { - l.readOneEntry(entry, l.Config.Labels, out) - } + for _, stream := range resp.Data.Result { + for 
_, entry := range stream.Entries { + l.readOneEntry(entry, l.Config.Labels, out) } - case <-t.Dying(): - return nil } + case <-ctx.Done(): + return nil } - }) - - return nil + } } diff --git a/pkg/acquisition/modules/s3/init.go b/pkg/acquisition/modules/s3/init.go index 2e022494981..c38f0ad7f89 100644 --- a/pkg/acquisition/modules/s3/init.go +++ b/pkg/acquisition/modules/s3/init.go @@ -7,11 +7,11 @@ import ( var ( // verify interface compliance - _ types.DataSource = (*Source)(nil) - _ types.DSNConfigurer = (*Source)(nil) - _ types.Fetcher = (*Source)(nil) - _ types.Tailer = (*Source)(nil) - _ types.MetricsProvider = (*Source)(nil) + _ types.DataSource = (*Source)(nil) + _ types.DSNConfigurer = (*Source)(nil) + _ types.BatchFetcher = (*Source)(nil) + _ types.RestartableStreamer = (*Source)(nil) + _ types.MetricsProvider = (*Source)(nil) ) const ModuleName = "s3" diff --git a/pkg/acquisition/modules/s3/run.go b/pkg/acquisition/modules/s3/run.go index 6b2d51b9a5a..9ac06cc8e98 100644 --- a/pkg/acquisition/modules/s3/run.go +++ b/pkg/acquisition/modules/s3/run.go @@ -20,7 +20,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/sqs" "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" - "gopkg.in/tomb.v2" + "golang.org/x/sync/errgroup" "github.com/crowdsecurity/crowdsec/pkg/metrics" "github.com/crowdsecurity/crowdsec/pkg/pipeline" @@ -67,26 +67,28 @@ const ( SQSFormatSNS = "sns" ) -func (s *Source) readManager() { +func (s *Source) readManager(ctx context.Context, out chan pipeline.Event, readerChan <-chan S3Object) { logger := s.logger.WithField("method", "readManager") for { select { - case <-s.t.Dying(): + case <-ctx.Done(): logger.Infof("Shutting down S3 read manager") - s.cancel() return - case s3Object := <-s.readerChan: + case s3Object, ok := <-readerChan: + if !ok { + return + } logger.Debugf("Reading file %s/%s", s3Object.Bucket, s3Object.Key) - if err := s.readFile(s3Object.Bucket, s3Object.Key); err != nil { + if err := 
s.readFile(ctx, out, s3Object.Bucket, s3Object.Key); err != nil { logger.Errorf("Error while reading file: %s", err) } } } } -func (s *Source) getBucketContent() ([]s3types.Object, error) { +func (s *Source) getBucketContent(ctx context.Context) ([]s3types.Object, error) { logger := s.logger.WithField("method", "getBucketContent") logger.Debugf("Getting bucket content") @@ -95,7 +97,7 @@ func (s *Source) getBucketContent() ([]s3types.Object, error) { var continuationToken *string for { - out, err := s.s3Client.ListObjectsV2(s.ctx, &s3.ListObjectsV2Input{ + out, err := s.s3Client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{ Bucket: aws.String(s.Config.BucketName), Prefix: aws.String(s.Config.Prefix), ContinuationToken: continuationToken, @@ -120,7 +122,7 @@ func (s *Source) getBucketContent() ([]s3types.Object, error) { return bucketObjects, nil } -func (s *Source) listPoll() error { +func (s *Source) listPoll(ctx context.Context, readerChan chan<- S3Object) error { logger := s.logger.WithField("method", "listPoll") ticker := time.NewTicker(time.Duration(s.Config.PollingInterval) * time.Second) lastObjectDate := time.Now() @@ -129,14 +131,13 @@ func (s *Source) listPoll() error { for { select { - case <-s.t.Dying(): + case <-ctx.Done(): logger.Infof("Shutting down list poller") - s.cancel() return nil case <-ticker.C: newObject := false - bucketObjects, err := s.getBucketContent() + bucketObjects, err := s.getBucketContent(ctx) if err != nil { logger.Errorf("Error while getting bucket content: %s", err) continue @@ -161,9 +162,9 @@ func (s *Source) listPoll() error { } select { - case s.readerChan <- obj: - case <-s.t.Dying(): - logger.Debug("tomb is dying, dropping object send") + case readerChan <- obj: + case <-ctx.Done(): + logger.Debug("context canceled, dropping object send") return nil } } @@ -261,19 +262,18 @@ func (s *Source) extractBucketAndPrefix(message *string) (string, string, error) } } -func (s *Source) sqsPoll() error { +func (s *Source) sqsPoll(ctx 
context.Context, readerChan chan<- S3Object) error { logger := s.logger.WithField("method", "sqsPoll") for { select { - case <-s.t.Dying(): + case <-ctx.Done(): logger.Infof("Shutting down SQS poller") - s.cancel() return nil default: logger.Trace("Polling SQS queue") - out, err := s.sqsClient.ReceiveMessage(s.ctx, &sqs.ReceiveMessageInput{ + out, err := s.sqsClient.ReceiveMessage(ctx, &sqs.ReceiveMessageInput{ QueueUrl: aws.String(s.Config.SQSName), MaxNumberOfMessages: 10, WaitTimeSeconds: 20, // Probably no need to make it configurable ? @@ -298,7 +298,7 @@ func (s *Source) sqsPoll() error { if err != nil { logger.Errorf("Error while parsing SQS message: %s", err) // Always delete the message to avoid infinite loop - _, err = s.sqsClient.DeleteMessage(s.ctx, + _, err = s.sqsClient.DeleteMessage(ctx, &sqs.DeleteMessageInput{ QueueUrl: aws.String(s.Config.SQSName), ReceiptHandle: message.ReceiptHandle, @@ -314,13 +314,13 @@ func (s *Source) sqsPoll() error { // don't block if readManager has quit select { - case s.readerChan <- S3Object{Key: key, Bucket: bucket}: - case <-s.t.Dying(): - logger.Debug("tomb is dying, dropping object send") + case readerChan <- S3Object{Key: key, Bucket: bucket}: + case <-ctx.Done(): + logger.Debug("context canceled, dropping object send") return nil } - _, err = s.sqsClient.DeleteMessage(s.ctx, + _, err = s.sqsClient.DeleteMessage(ctx, &sqs.DeleteMessageInput{ QueueUrl: aws.String(s.Config.SQSName), ReceiptHandle: message.ReceiptHandle, @@ -335,7 +335,7 @@ func (s *Source) sqsPoll() error { } } -func (s *Source) readFile(bucket string, key string) error { +func (s *Source) readFile(ctx context.Context, out chan pipeline.Event, bucket string, key string) error { // TODO: Handle SSE-C var scanner *bufio.Scanner @@ -345,7 +345,7 @@ func (s *Source) readFile(bucket string, key string) error { "key": key, }) - output, err := s.s3Client.GetObject(s.ctx, &s3.GetObjectInput{ + output, err := s.s3Client.GetObject(ctx, &s3.GetObjectInput{ 
Bucket: aws.String(bucket), Key: aws.String(key), }) @@ -385,7 +385,7 @@ func (s *Source) readFile(bucket string, key string) error { for scanner.Scan() { select { - case <-s.t.Dying(): + case <-ctx.Done(): s.logger.Infof("Shutting down reader for %s/%s", bucket, key) return nil default: @@ -415,9 +415,9 @@ func (s *Source) readFile(bucket string, key string) error { // don't block in shutdown select { - case s.out <-evt: - case <-s.t.Dying(): - s.logger.Infof("tomb is dying, dropping event for %s/%s", bucket, key) + case out <- evt: + case <-ctx.Done(): + s.logger.Infof("context canceled, dropping event for %s/%s", bucket, key) return nil } } @@ -434,58 +434,57 @@ func (s *Source) readFile(bucket string, key string) error { return nil } -func (s *Source) OneShotAcquisition(ctx context.Context, out chan pipeline.Event, t *tomb.Tomb) error { +func (s *Source) OneShot(ctx context.Context, out chan pipeline.Event) error { s.logger.Infof("starting acquisition of %s/%s/%s", s.Config.BucketName, s.Config.Prefix, s.Config.Key) - s.out = out - s.ctx, s.cancel = context.WithCancel(ctx) s.Config.UseTimeMachine = true - s.t = t if s.Config.Key != "" { - err := s.readFile(s.Config.BucketName, s.Config.Key) + err := s.readFile(ctx, out, s.Config.BucketName, s.Config.Key) if err != nil { return err } } else { // No key, get everything in the bucket based on the prefix - objects, err := s.getBucketContent() + objects, err := s.getBucketContent(ctx) if err != nil { return err } for _, object := range objects { - err := s.readFile(s.Config.BucketName, *object.Key) + err := s.readFile(ctx, out, s.Config.BucketName, *object.Key) if err != nil { return err } } } - t.Kill(nil) - return nil } -func (s *Source) StreamingAcquisition(ctx context.Context, out chan pipeline.Event, t *tomb.Tomb) error { - s.t = t - s.out = out - s.readerChan = make(chan S3Object, 100) // FIXME: does this needs to be buffered? 
- s.ctx, s.cancel = context.WithCancel(ctx) +func (s *Source) Stream(ctx context.Context, out chan pipeline.Event) error { + readerChan := make(chan S3Object, 100) // FIXME: does this need to be buffered? s.logger.Infof("starting acquisition of %s/%s", s.Config.BucketName, s.Config.Prefix) - t.Go(func() error { - s.readManager() + + g, gctx := errgroup.WithContext(ctx) + g.Go(func() error { + s.readManager(gctx, out, readerChan) return nil }) if s.Config.PollingMethod == PollMethodSQS { - t.Go(func() error { - return s.sqsPoll() + g.Go(func() error { + return s.sqsPoll(gctx, readerChan) }) } else { - t.Go(func() error { - return s.listPoll() + g.Go(func() error { + return s.listPoll(gctx, readerChan) }) } - return nil + err := g.Wait() + if errors.Is(err, context.Canceled) { + return nil + } + + return err } diff --git a/pkg/acquisition/modules/s3/s3_test.go b/pkg/acquisition/modules/s3/s3_test.go index 8557a8fa37e..e0d3fa463df 100644 --- a/pkg/acquisition/modules/s3/s3_test.go +++ b/pkg/acquisition/modules/s3/s3_test.go @@ -17,7 +17,6 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/crowdsec/pkg/metrics" "github.com/crowdsecurity/crowdsec/pkg/pipeline" @@ -187,8 +186,7 @@ func TestDSNAcquis(t *testing.T) { } }() - tmb := tomb.Tomb{} - err = f.OneShotAcquisition(ctx, out, &tmb) + err = f.OneShot(ctx, out) require.NoError(t, err) time.Sleep(2 * time.Second) @@ -246,7 +244,9 @@ prefix: foo/ } out := make(chan pipeline.Event) - tb := tomb.Tomb{} + done := make(chan struct{}) + streamCtx, cancel := context.WithCancel(ctx) + streamErr := make(chan error, 1) go func() { for { @@ -254,21 +254,20 @@ prefix: foo/ case s := <-out: fmt.Printf("got line %s\n", s.Line.Raw) linesRead++ - case <-tb.Dying(): + case <-done: return } } }() - err = f.StreamingAcquisition(ctx, out, &tb) - if err != nil { - t.Fatalf("unexpected error: %s", err.Error()) - } + go func() 
{ + streamErr <- f.Stream(streamCtx, out) + }() time.Sleep(2 * time.Second) - tb.Kill(nil) - err = tb.Wait() - require.NoError(t, err) + cancel() + require.NoError(t, <-streamErr) + close(done) assert.Equal(t, test.expectedCount, linesRead) }) } @@ -342,7 +341,9 @@ sqs_name: test } out := make(chan pipeline.Event) - tb := tomb.Tomb{} + done := make(chan struct{}) + streamCtx, cancel := context.WithCancel(ctx) + streamErr := make(chan error, 1) go func() { for { @@ -350,22 +351,20 @@ sqs_name: test case s := <-out: fmt.Printf("got line %s\n", s.Line.Raw) linesRead++ - case <-tb.Dying(): + case <-done: return } } }() - err = f.StreamingAcquisition(ctx, out, &tb) - if err != nil { - t.Fatalf("unexpected error: %s", err.Error()) - } + go func() { + streamErr <- f.Stream(streamCtx, out) + }() time.Sleep(2 * time.Second) - tb.Kill(nil) - - err = tb.Wait() - require.NoError(t, err) + cancel() + require.NoError(t, <-streamErr) + close(done) assert.Equal(t, test.expectedCount, linesRead) }) } diff --git a/pkg/acquisition/modules/s3/source.go b/pkg/acquisition/modules/s3/source.go index a1784f5d68d..851adc82409 100644 --- a/pkg/acquisition/modules/s3/source.go +++ b/pkg/acquisition/modules/s3/source.go @@ -2,13 +2,12 @@ package s3acquisition import ( "context" + s3Manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager" "github.com/aws/aws-sdk-go-v2/service/sqs" log "github.com/sirupsen/logrus" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/crowdsec/pkg/metrics" - "github.com/crowdsecurity/crowdsec/pkg/pipeline" ) type S3API interface { @@ -27,11 +26,6 @@ type Source struct { logger *log.Entry s3Client S3API sqsClient SQSAPI - readerChan chan S3Object - t *tomb.Tomb - out chan pipeline.Event - ctx context.Context - cancel context.CancelFunc } type S3Object struct { diff --git a/pkg/acquisition/modules/victorialogs/init.go b/pkg/acquisition/modules/victorialogs/init.go index 5ab9e5eb8dd..08b842f677f 100644 --- a/pkg/acquisition/modules/victorialogs/init.go +++ 
b/pkg/acquisition/modules/victorialogs/init.go @@ -7,11 +7,11 @@ import ( var ( // verify interface compliance - _ types.DataSource = (*Source)(nil) - _ types.DSNConfigurer = (*Source)(nil) - _ types.Fetcher = (*Source)(nil) - _ types.Tailer = (*Source)(nil) - _ types.MetricsProvider = (*Source)(nil) + _ types.DataSource = (*Source)(nil) + _ types.DSNConfigurer = (*Source)(nil) + _ types.BatchFetcher = (*Source)(nil) + _ types.RestartableStreamer = (*Source)(nil) + _ types.MetricsProvider = (*Source)(nil) ) const ModuleName = "victorialogs" diff --git a/pkg/acquisition/modules/victorialogs/internal/vlclient/vl_client.go b/pkg/acquisition/modules/victorialogs/internal/vlclient/vl_client.go index 81df3f0bd1e..136a030d8a8 100644 --- a/pkg/acquisition/modules/victorialogs/internal/vlclient/vl_client.go +++ b/pkg/acquisition/modules/victorialogs/internal/vlclient/vl_client.go @@ -15,7 +15,6 @@ import ( "time" log "github.com/sirupsen/logrus" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/crowdsec/pkg/apiclient/useragent" "maps" @@ -25,7 +24,6 @@ type VLClient struct { Logger *log.Entry config Config - t *tomb.Tomb failStart time.Time currentTickerInterval time.Duration requestHeaders map[string]string @@ -68,10 +66,6 @@ func updateURI(uri string, newStart time.Time) (string, error) { return u.String(), nil } -func (lc *VLClient) SetTomb(t *tomb.Tomb) { - lc.t = t -} - func (lc *VLClient) shouldRetry() bool { if lc.failStart.IsZero() { lc.Logger.Warningf("VictoriaLogs is not available, will retry for %s", lc.config.FailMaxDuration) @@ -118,8 +112,6 @@ func (lc *VLClient) doQueryRange(ctx context.Context, uri string, c chan *Log, i select { case <-ctx.Done(): return ctx.Err() - case <-lc.t.Dying(): - return lc.t.Err() case <-ticker.C: resp, err := lc.Get(ctx, uri) if err != nil { @@ -268,9 +260,6 @@ func (lc *VLClient) Ready(ctx context.Context) error { case <-ctx.Done(): tick.Stop() return ctx.Err() - case <-lc.t.Dying(): - tick.Stop() - return lc.t.Err() case <-tick.C: 
lc.Logger.Debug("Checking if VictoriaLogs is ready") @@ -313,10 +302,14 @@ func (lc *VLClient) Tail(ctx context.Context) (chan *Log, error) { ) for { - resp, err = lc.Get(ctx, u) + resp, err = lc.Get(ctx, u) //nolint:bodyclose // body is closed in error paths and via defer in the goroutine lc.Logger.Tracef("Tail request done: %v | %s", resp, err) if err != nil { + if resp != nil { + resp.Body.Close() + } + if errors.Is(err, context.Canceled) { return nil, nil } @@ -343,14 +336,17 @@ func (lc *VLClient) Tail(ctx context.Context) (chan *Log, error) { responseChan := make(chan *Log) - lc.t.Go(func() error { + lc.Logger.Infof("Connecting to %s", u) + + go func() { + defer resp.Body.Close() + _, _, err = lc.readResponse(ctx, resp, responseChan) if err != nil { - return fmt.Errorf("error while reading tail response: %w", err) + lc.Logger.Errorf("error while reading tail response: %s", err) } - - return nil - }) + close(responseChan) + }() return responseChan, nil } @@ -370,9 +366,11 @@ func (lc *VLClient) QueryRange(ctx context.Context, infinite bool) chan *Log { lc.Logger.Debugf("Since: %s (%s)", lc.config.Since, t) lc.Logger.Infof("Connecting to %s", u) - lc.t.Go(func() error { - return lc.doQueryRange(ctx, u, c, infinite) - }) + go func() { + if err := lc.doQueryRange(ctx, u, c, infinite); err != nil { + lc.Logger.Errorf("Error querying range: %s", err) + } + }() return c } diff --git a/pkg/acquisition/modules/victorialogs/run.go b/pkg/acquisition/modules/victorialogs/run.go index 98eb8eebd76..24a1cd83cce 100644 --- a/pkg/acquisition/modules/victorialogs/run.go +++ b/pkg/acquisition/modules/victorialogs/run.go @@ -5,7 +5,6 @@ import ( "fmt" "github.com/prometheus/client_golang/prometheus" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/victorialogs/internal/vlclient" @@ -13,10 +12,8 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/pipeline" ) -// OneShotAcquisition 
reads a set of file and returns when done -func (s *Source) OneShotAcquisition(ctx context.Context, out chan pipeline.Event, t *tomb.Tomb) error { +func (s *Source) OneShot(ctx context.Context, out chan pipeline.Event) error { s.logger.Debug("VictoriaLogs one shot acquisition") - s.Client.SetTomb(t) readyCtx, cancel := context.WithTimeout(ctx, s.Config.WaitForReady) defer cancel() @@ -36,7 +33,7 @@ func (s *Source) OneShotAcquisition(ctx context.Context, out chan pipeline.Event for { select { - case <-t.Dying(): + case <-ctx.Done(): s.logger.Debug("VictoriaLogs one shot acquisition stopped") return nil case resp, ok := <-respChan: @@ -76,9 +73,7 @@ func (s *Source) readOneEntry(entry *vlclient.Log, labels map[string]string, out } } -func (s *Source) StreamingAcquisition(ctx context.Context, out chan pipeline.Event, t *tomb.Tomb) error { - s.Client.SetTomb(t) - +func (s *Source) Stream(ctx context.Context, out chan pipeline.Event) error { readyCtx, cancel := context.WithTimeout(ctx, s.Config.WaitForReady) defer cancel() @@ -87,44 +82,25 @@ func (s *Source) StreamingAcquisition(ctx context.Context, out chan pipeline.Eve return fmt.Errorf("VictoriaLogs is not ready: %w", err) } - lctx, clientCancel := context.WithCancel(ctx) - // Don't defer clientCancel(), the client outlives this function call - - t.Go(func() error { - <-t.Dying() - clientCancel() - - return nil - }) - - t.Go(func() error { - respChan, err := s.getResponseChan(lctx, true) - if err != nil { - clientCancel() - s.logger.Errorf("could not start VictoriaLogs tail: %s", err) - - return fmt.Errorf("while starting VictoriaLogs tail: %w", err) - } - - for { - select { - case resp, ok := <-respChan: - if !ok { - s.logger.Warnf("VictoriaLogs channel closed") - clientCancel() - - return err - } + respChan, err := s.getResponseChan(ctx, true) + if err != nil { + s.logger.Errorf("could not start VictoriaLogs tail: %s", err) + return fmt.Errorf("while starting VictoriaLogs tail: %w", err) + } - 
s.readOneEntry(resp, s.Config.Labels, out) - case <-t.Dying(): - clientCancel() - return nil + for { + select { + case resp, ok := <-respChan: + if !ok { + s.logger.Warnf("VictoriaLogs channel closed") + return err } - } - }) - return nil + s.readOneEntry(resp, s.Config.Labels, out) + case <-ctx.Done(): + return nil + } + } } func (s *Source) getResponseChan(ctx context.Context, infinite bool) (chan *vlclient.Log, error) { diff --git a/pkg/acquisition/modules/victorialogs/victorialogs_test.go b/pkg/acquisition/modules/victorialogs/victorialogs_test.go index d082986694a..b434ff3e3ba 100644 --- a/pkg/acquisition/modules/victorialogs/victorialogs_test.go +++ b/pkg/acquisition/modules/victorialogs/victorialogs_test.go @@ -17,7 +17,6 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/go-cs-lib/cstest" @@ -211,9 +210,7 @@ since: 1h } }() - vlTomb := tomb.Tomb{} - - err = vlSource.OneShotAcquisition(ctx, out, &vlTomb) + err = vlSource.OneShot(ctx, out) if err != nil { t.Fatalf("Unexpected error : %s", err) } @@ -273,7 +270,6 @@ query: > }) out := make(chan pipeline.Event) - vlTomb := tomb.Tomb{} vlSource := victorialogs.Source{} err := vlSource.Configure(ctx, []byte(ts.config), subLogger, metrics.AcquisitionMetricsLevelNone) @@ -281,48 +277,58 @@ query: > t.Fatalf("Unexpected error : %s", err) } - err = vlSource.StreamingAcquisition(ctx, out, &vlTomb) - cstest.AssertErrorContains(t, err, ts.streamErr) + streamCtx, streamCancel := context.WithCancel(ctx) + defer streamCancel() + streamErrCh := make(chan error, 1) + + go func() { + streamErrCh <- vlSource.Stream(streamCtx, out) + }() if ts.streamErr != "" { + err = <-streamErrCh + cstest.AssertErrorContains(t, err, ts.streamErr) return } time.Sleep(time.Second * 2) // We need to give time to start reading from the WS - readTomb := tomb.Tomb{} - readCtx, cancel := context.WithTimeout(ctx, time.Second*10) + 
readCtx, readCancel := context.WithTimeout(ctx, time.Second*10) count := 0 + readErrCh := make(chan error, 1) - readTomb.Go(func() error { - defer cancel() + go func() { + defer readCancel() for { select { case <-readCtx.Done(): - return readCtx.Err() + readErrCh <- readCtx.Err() + return case evt := <-out: count++ if !strings.HasSuffix(evt.Line.Raw, title) { - return fmt.Errorf("Incorrect suffix : %s", evt.Line.Raw) + readErrCh <- fmt.Errorf("Incorrect suffix : %s", evt.Line.Raw) + return } if count == ts.expectedLines { - return nil + readErrCh <- nil + return } } } - }) + }() err = feedVLogs(ctx, subLogger, ts.expectedLines, title) if err != nil { t.Fatalf("Unexpected error : %s", err) } - err = readTomb.Wait() + err = <-readErrCh - cancel() + readCancel() if err != nil { t.Fatalf("Unexpected error : %s", err) @@ -357,12 +363,13 @@ query: > out := make(chan pipeline.Event, 10) - vlTomb := &tomb.Tomb{} + streamCtx, streamCancel := context.WithCancel(ctx) + streamDone := make(chan struct{}) - err = vlSource.StreamingAcquisition(ctx, out, vlTomb) - if err != nil { - t.Fatalf("Unexpected error : %s", err) - } + go func() { + _ = vlSource.Stream(streamCtx, out) + close(streamDone) + }() time.Sleep(time.Second * 2) @@ -371,10 +378,7 @@ query: > t.Fatalf("Unexpected error : %s", err) } - vlTomb.Kill(nil) + streamCancel() - err = vlTomb.Wait() - if err != nil { - t.Fatalf("Unexpected error : %s", err) - } + <-streamDone } diff --git a/pkg/acquisition/modules/wineventlog/init.go b/pkg/acquisition/modules/wineventlog/init.go index 9befef73d7d..bda982a210b 100644 --- a/pkg/acquisition/modules/wineventlog/init.go +++ b/pkg/acquisition/modules/wineventlog/init.go @@ -7,11 +7,11 @@ import ( var ( // verify interface compliance - _ types.DataSource = (*Source)(nil) - _ types.DSNConfigurer = (*Source)(nil) - _ types.BatchFetcher = (*Source)(nil) - _ types.Tailer = (*Source)(nil) - _ types.MetricsProvider = (*Source)(nil) + _ types.DataSource = (*Source)(nil) + _ 
types.DSNConfigurer = (*Source)(nil) + _ types.BatchFetcher = (*Source)(nil) + _ types.RestartableStreamer = (*Source)(nil) + _ types.MetricsProvider = (*Source)(nil) ) const ModuleName = "wineventlog" diff --git a/pkg/acquisition/modules/wineventlog/run_windows.go b/pkg/acquisition/modules/wineventlog/run_windows.go index 10c86869148..85ee9e1a61a 100644 --- a/pkg/acquisition/modules/wineventlog/run_windows.go +++ b/pkg/acquisition/modules/wineventlog/run_windows.go @@ -11,9 +11,6 @@ import ( "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" "golang.org/x/sys/windows" - "gopkg.in/tomb.v2" - - "github.com/crowdsecurity/go-cs-lib/trace" "github.com/crowdsecurity/crowdsec/pkg/metrics" "github.com/crowdsecurity/crowdsec/pkg/pipeline" @@ -63,7 +60,7 @@ func (s *Source) getXMLEvents(config *winlog.SubscribeConfig, publisherCache map return renderedEvents, err } -func (s *Source) getEvents(out chan pipeline.Event, t *tomb.Tomb) error { +func (s *Source) getEvents(ctx context.Context, out chan pipeline.Event) error { subscription, err := winlog.Subscribe(s.evtConfig) if err != nil { s.logger.Errorf("Failed to subscribe to event log: %s", err) @@ -78,7 +75,7 @@ func (s *Source) getEvents(out chan pipeline.Event, t *tomb.Tomb) error { }() for { select { - case <-t.Dying(): + case <-ctx.Done(): s.logger.Infof("wineventlog is dying") return nil default: @@ -170,10 +167,6 @@ OUTER_LOOP: return nil } -func (s *Source) StreamingAcquisition(ctx context.Context, out chan pipeline.Event, t *tomb.Tomb) error { - t.Go(func() error { - defer trace.ReportPanic() - return s.getEvents(out, t) - }) - return nil +func (s *Source) Stream(ctx context.Context, out chan pipeline.Event) error { + return s.getEvents(ctx, out) } diff --git a/pkg/acquisition/modules/wineventlog/stub.go b/pkg/acquisition/modules/wineventlog/stub.go index 697968e0911..788f32346c6 100644 --- a/pkg/acquisition/modules/wineventlog/stub.go +++ b/pkg/acquisition/modules/wineventlog/stub.go 
@@ -8,7 +8,6 @@ import ( "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" "github.com/crowdsecurity/crowdsec/pkg/metrics" @@ -61,7 +60,7 @@ func (*Source) CanRun() error { return errors.New("windows event log acquisition is only supported on Windows") } -func (*Source) StreamingAcquisition(_ context.Context, _ chan pipeline.Event, _ *tomb.Tomb) error { +func (*Source) Stream(_ context.Context, _ chan pipeline.Event) error { return nil } diff --git a/pkg/acquisition/modules/wineventlog/wineventlog_windows_test.go b/pkg/acquisition/modules/wineventlog/wineventlog_windows_test.go index 7b94228c1a0..d689d16b87e 100644 --- a/pkg/acquisition/modules/wineventlog/wineventlog_windows_test.go +++ b/pkg/acquisition/modules/wineventlog/wineventlog_windows_test.go @@ -3,6 +3,7 @@ package wineventlogacquisition import ( + "context" "testing" "time" @@ -10,7 +11,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/sys/windows/svc/eventlog" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/go-cs-lib/cstest" @@ -141,8 +141,6 @@ event_level: bla`, } func TestLiveAcquisition(t *testing.T) { - ctx := t.Context() - err := exprhelpers.Init(nil) require.NoError(t, err) @@ -198,15 +196,17 @@ event_ids: } for _, test := range tests { - to := &tomb.Tomb{} + ctx, cancel := context.WithCancel(t.Context()) c := make(chan pipeline.Event) f := Source{} err := f.Configure(ctx, []byte(test.config), subLogger, metrics.AcquisitionMetricsLevelNone) require.NoError(t, err) - err = f.StreamingAcquisition(ctx, c, to) - require.NoError(t, err) + streamErr := make(chan error, 1) + go func() { + streamErr <- f.Stream(ctx, c) + }() time.Sleep(time.Second) lines := test.expectedLines @@ -238,8 +238,8 @@ event_ids: } else { assert.Equal(t, test.expectedLines, linesRead) } - to.Kill(nil) - _ = to.Wait() + cancel() + <-streamErr } } diff --git 
a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index 23a022901b6..297b2b2b51b 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -18,7 +18,6 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/go-openapi/strfmt" log "github.com/sirupsen/logrus" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/go-cs-lib/ptr" @@ -61,9 +60,10 @@ type apic struct { AlertsAddChan chan []*models.Alert mu sync.Mutex - pushTomb tomb.Tomb - pullTomb tomb.Tomb - metricsTomb tomb.Tomb + metricsCancel context.CancelFunc + metricsDone chan struct{} + stopChan chan struct{} + stopOnce sync.Once startup bool consoleConfig *csconfig.ConsoleConfig isPulling chan bool @@ -193,9 +193,7 @@ func NewAPIC(ctx context.Context, config *csconfig.OnlineApiClientCfg, dbClient dbClient: dbClient, mu: sync.Mutex{}, startup: true, - pullTomb: tomb.Tomb{}, - pushTomb: tomb.Tomb{}, - metricsTomb: tomb.Tomb{}, + stopChan: make(chan struct{}), consoleConfig: consoleConfig, pullInterval: pullIntervalDefault, pullIntervalFirst: randomDuration(pullIntervalDefault, pullIntervalDelta), @@ -288,23 +286,28 @@ func (a *apic) Push(ctx context.Context) error { var cache models.AddSignalsRequest ticker := time.NewTicker(a.pushIntervalFirst) + defer ticker.Stop() log.Infof("Start push to CrowdSec Central API (interval: %s once, then %s)", a.pushIntervalFirst.Round(time.Second), a.pushInterval) - for { - select { - case <-a.pushTomb.Dying(): // if one apic routine is dying, do we kill the others? 
- a.pullTomb.Kill(nil) - a.metricsTomb.Kill(nil) - log.Infof("push tomb is dying, sending cache (%d elements) before exiting", len(cache)) + flushCache := func() error { + log.Infof("stopping push routine, sending cache (%d elements) before exiting", len(cache)) - if len(cache) == 0 { - return nil - } + if len(cache) == 0 { + return nil + } - go a.Send(ctx, &cache) + go a.Send(context.WithoutCancel(ctx), &cache) - return nil + return nil + } + + for { + select { + case <-ctx.Done(): + return flushCache() + case <-a.stopChan: + return flushCache() case <-ticker.C: ticker.Reset(a.pushInterval) @@ -1038,6 +1041,14 @@ func (a *apic) Pull(ctx context.Context) error { toldOnce := false for { + select { + case <-ctx.Done(): + return nil + case <-a.stopChan: + return nil + default: + } + scenario, err := a.FetchScenariosListFromDB(ctx) if err != nil { log.Errorf("unable to fetch scenarios from db: %s", err) @@ -1062,6 +1073,7 @@ func (a *apic) Pull(ctx context.Context) error { log.Infof("Start pull from CrowdSec Central API (interval: %s once, then %s)", a.pullIntervalFirst.Round(time.Second), a.pullInterval) ticker := time.NewTicker(a.pullIntervalFirst) + defer ticker.Stop() for { select { @@ -1072,19 +1084,72 @@ func (a *apic) Pull(ctx context.Context) error { log.Errorf("capi pull top: %s", err) continue } - case <-a.pullTomb.Dying(): // if one apic routine is dying, do we kill the others? 
- a.metricsTomb.Kill(nil) - a.pushTomb.Kill(nil) - + case <-ctx.Done(): return nil + case <-a.stopChan: + return nil + } + } +} + +func (a *apic) StartMetrics(ctx context.Context, sendUsageMetrics bool) { + a.mu.Lock() + if a.metricsCancel != nil { + a.mu.Unlock() + return + } + + metricsCtx, cancel := context.WithCancel(ctx) + done := make(chan struct{}) + a.metricsCancel = cancel + a.metricsDone = done + a.mu.Unlock() + + go func() { + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + a.SendMetrics(metricsCtx, make(chan bool)) + }() + + if sendUsageMetrics { + wg.Add(1) + go func() { + defer wg.Done() + a.SendUsageMetrics(metricsCtx) + }() } + + wg.Wait() + close(done) + }() +} + +func (a *apic) StopMetrics() { + a.mu.Lock() + cancel := a.metricsCancel + done := a.metricsDone + a.metricsCancel = nil + a.metricsDone = nil + a.mu.Unlock() + + if cancel != nil { + cancel() + } + + if done != nil { + <-done } } func (a *apic) Shutdown() { - a.pushTomb.Kill(nil) - a.pullTomb.Kill(nil) - a.metricsTomb.Kill(nil) + a.stopOnce.Do(func() { + close(a.stopChan) + }) + + a.StopMetrics() } func makeAddAndDeleteCounters() (map[string]map[string]int, map[string]map[string]int) { diff --git a/pkg/apiserver/apic_metrics.go b/pkg/apiserver/apic_metrics.go index bc49b0cdc73..1cae300e2d7 100644 --- a/pkg/apiserver/apic_metrics.go +++ b/pkg/apiserver/apic_metrics.go @@ -291,13 +291,14 @@ func (a *apic) SendMetrics(ctx context.Context, stop chan bool) { checkTicker := time.NewTicker(checkInt) metTicker := time.NewTicker(nextMetInt()) + defer checkTicker.Stop() + defer metTicker.Stop() for { select { case <-stop: - checkTicker.Stop() - metTicker.Stop() - + return + case <-ctx.Done(): return case <-checkTicker.C: oldIDs := machineIDs @@ -326,13 +327,6 @@ func (a *apic) SendMetrics(ctx context.Context, stop chan bool) { } metTicker.Reset(nextMetInt()) - case <-a.metricsTomb.Dying(): // if one apic routine is dying, do we kill the others? 
- checkTicker.Stop() - metTicker.Stop() - a.pullTomb.Kill(nil) - a.pushTomb.Kill(nil) - - return } } } @@ -345,8 +339,7 @@ func (a *apic) SendUsageMetrics(ctx context.Context) { for { select { - case <-a.metricsTomb.Dying(): - // The normal metrics routine also kills push/pull tombs, does that make sense ? + case <-ctx.Done(): ticker.Stop() return case <-ticker.C: diff --git a/pkg/apiserver/apic_test.go b/pkg/apiserver/apic_test.go index 3015a2a24ba..8fe11eca459 100644 --- a/pkg/apiserver/apic_test.go +++ b/pkg/apiserver/apic_test.go @@ -19,7 +19,6 @@ import ( "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/go-cs-lib/cstest" "github.com/crowdsecurity/go-cs-lib/ptr" @@ -57,12 +56,10 @@ func getAPIC(t *testing.T, ctx context.Context) *apic { return &apic{ AlertsAddChan: make(chan []*models.Alert), // DecisionDeleteChan: make(chan []*models.Decision), - dbClient: dbClient, - mu: sync.Mutex{}, - startup: true, - pullTomb: tomb.Tomb{}, - pushTomb: tomb.Tomb{}, - metricsTomb: tomb.Tomb{}, + dbClient: dbClient, + mu: sync.Mutex{}, + startup: true, + stopChan: make(chan struct{}), consoleConfig: &csconfig.ConsoleConfig{ ShareManualDecisions: ptr.Of(false), ShareTaintedScenarios: ptr.Of(false), diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index a06c1ba9bde..1c440d3099b 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -17,7 +17,6 @@ import ( "github.com/gin-gonic/gin" "github.com/go-co-op/gocron/v2" log "github.com/sirupsen/logrus" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/go-cs-lib/trace" @@ -41,7 +40,6 @@ type APIServer struct { httpServer *http.Server apic *apic papi *Papi - httpServerTomb tomb.Tomb } func isBrokenConnection(maybeError any) bool { @@ -240,7 +238,6 @@ func NewServer(ctx context.Context, config *csconfig.LocalApiServerCfg, accessLo router: router, apic: apiClient, papi: papiClient, - httpServerTomb: 
tomb.Tomb{}, }, nil } @@ -275,37 +272,25 @@ func (s *APIServer) papiPull(ctx context.Context) error { return nil } -func (s *APIServer) papiSync(ctx context.Context) error { - if err := s.papi.SyncDecisions(ctx); err != nil { - log.Errorf("capi decisions sync: %s", err) - return err - } - - return nil -} - func (s *APIServer) initAPIC(ctx context.Context) { - s.apic.pushTomb.Go(func() error { + go func() { defer trace.ReportPanic() - return s.apicPush(ctx) - }) - s.apic.pullTomb.Go(func() error { + _ = s.apicPush(ctx) + }() + go func() { defer trace.ReportPanic() - return s.apicPull(ctx) - }) + _ = s.apicPull(ctx) + }() if s.apic.apiClient.IsEnrolled() { if s.papi != nil { if s.papi.URL != "" { log.Info("Starting PAPI decision receiver") - s.papi.pullTomb.Go(func() error { + go func() { defer trace.ReportPanic() - return s.papiPull(ctx) - }) - s.papi.syncTomb.Go(func() error { - defer trace.ReportPanic() - return s.papiSync(ctx) - }) + _ = s.papiPull(ctx) + }() + s.papi.StartSync(ctx) } else { log.Warnf("papi_url is not set in online_api_credentials.yaml, can't synchronize with the console. 
Run cscli console enable console_management to add it.") } @@ -314,19 +299,7 @@ func (s *APIServer) initAPIC(ctx context.Context) { } } - s.apic.metricsTomb.Go(func() error { - defer trace.ReportPanic() - s.apic.SendMetrics(ctx, make(chan bool)) - return nil - }) - - if !s.cfg.DisableUsageMetricsExport { - s.apic.metricsTomb.Go(func() error { - defer trace.ReportPanic() - s.apic.SendUsageMetrics(ctx) - return nil - }) - } + s.apic.StartMetrics(ctx, !s.cfg.DisableUsageMetricsExport) } func (s *APIServer) Run(ctx context.Context, apiReady chan bool) error { @@ -350,11 +323,7 @@ func (s *APIServer) Run(ctx context.Context, apiReady chan bool) error { s.initAPIC(ctx) } - s.httpServerTomb.Go(func() error { - return s.listenAndServeLAPI(ctx, apiReady) - }) - - if err := s.httpServerTomb.Wait(); err != nil { + if err := s.listenAndServeLAPI(ctx, apiReady); err != nil { return fmt.Errorf("local API server stopped with error: %w", err) } @@ -390,7 +359,10 @@ func (s *APIServer) listenAndServeLAPI(ctx context.Context, apiReady chan bool) switch { case errors.Is(err, http.ErrServerClosed): - break + select { + case serverError <- nil: + default: + } case err != nil: serverError <- err } @@ -439,10 +411,10 @@ func (s *APIServer) listenAndServeLAPI(ctx context.Context, apiReady chan bool) select { case err := <-serverError: return err - case <-s.httpServerTomb.Dying(): + case <-ctx.Done(): log.Info("Shutting down API server") - ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + ctx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second) defer cancel() if err := s.httpServer.Shutdown(ctx); err != nil { @@ -495,12 +467,6 @@ func (s *APIServer) Shutdown(ctx context.Context) error { pipe.Close() } - s.httpServerTomb.Kill(nil) - - if err := s.httpServerTomb.Wait(); err != nil { - return fmt.Errorf("while waiting on httpServerTomb: %w", err) - } - return nil } diff --git a/pkg/apiserver/papi.go b/pkg/apiserver/papi.go index f7b512e9a7a..113272c5a25 100644 --- 
a/pkg/apiserver/papi.go +++ b/pkg/apiserver/papi.go @@ -10,7 +10,6 @@ import ( "time" log "github.com/sirupsen/logrus" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/crowdsec/pkg/apiclient" "github.com/crowdsecurity/crowdsec/pkg/csconfig" @@ -65,8 +64,8 @@ type Papi struct { apiClient *apiclient.ApiClient Channels *OperationChannels mu sync.Mutex - pullTomb tomb.Tomb - syncTomb tomb.Tomb + syncCancel context.CancelFunc + syncDone chan struct{} SyncInterval time.Duration consoleConfig *csconfig.ConsoleConfig Logger *log.Entry @@ -113,8 +112,6 @@ func NewPAPI(apic *apic, dbClient *database.Client, consoleConfig *csconfig.Cons Channels: channels, SyncInterval: SyncInterval, mu: sync.Mutex{}, - pullTomb: tomb.Tomb{}, - syncTomb: tomb.Tomb{}, apiClient: apic.apiClient, apic: apic, consoleConfig: consoleConfig, @@ -296,7 +293,13 @@ func (p *Papi) Pull(ctx context.Context) error { papiChan = nil p.Logger.Debug("done stopping PAPI pull") } - case event := <-papiChan: + case event, ok := <-papiChan: + if !ok { + p.Logger.Debug("PAPI event stream closed") + papiChan = nil + continue + } + logger := p.Logger.WithField("request-id", event.RequestId) // update last timestamp in database newTime := time.Now().UTC() @@ -323,22 +326,59 @@ func (p *Papi) Pull(ctx context.Context) error { } } +func (p *Papi) StartSync(ctx context.Context) { + p.mu.Lock() + if p.syncCancel != nil { + p.mu.Unlock() + return + } + + syncCtx, cancel := context.WithCancel(ctx) + done := make(chan struct{}) + p.syncCancel = cancel + p.syncDone = done + p.mu.Unlock() + + go func() { + defer close(done) + _ = p.SyncDecisions(syncCtx) + }() +} + +func (p *Papi) StopSync() { + p.mu.Lock() + cancel := p.syncCancel + done := p.syncDone + p.syncCancel = nil + p.syncDone = nil + p.mu.Unlock() + + if cancel != nil { + cancel() + } + + if done != nil { + <-done + } +} + func (p *Papi) SyncDecisions(ctx context.Context) error { var cache models.DecisionsDeleteRequest ticker := time.NewTicker(p.SyncInterval) + defer 
ticker.Stop() p.Logger.Infof("Start decisions sync to CrowdSec Central API (interval: %s)", p.SyncInterval) for { select { - case <-p.syncTomb.Dying(): // if one apic routine is dying, do we kill the others? - p.Logger.Infof("sync decisions tomb is dying, sending cache (%d elements) before exiting", len(cache)) + case <-ctx.Done(): + p.Logger.Infof("sync decisions context canceled, sending cache (%d elements) before exiting", len(cache)) if len(cache) == 0 { return nil } - go p.SendDeletedDecisions(ctx, &cache) + go p.SendDeletedDecisions(context.WithoutCancel(ctx), &cache) return nil case <-ticker.C: @@ -398,7 +438,7 @@ func (p *Papi) SendDeletedDecisions(ctx context.Context, cacheOrig *models.Decis func (p *Papi) Shutdown() { p.Logger.Infof("Shutting down PAPI") - p.syncTomb.Kill(nil) + p.StopSync() select { case p.stopChan <- struct{}{}: // Cancel any HTTP request still in progress default: diff --git a/pkg/appsec/allowlists/allowlists.go b/pkg/appsec/allowlists/allowlists.go index 6e3286ce2c9..34e70a9b124 100644 --- a/pkg/appsec/allowlists/allowlists.go +++ b/pkg/appsec/allowlists/allowlists.go @@ -8,7 +8,6 @@ import ( "time" log "github.com/sirupsen/logrus" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/crowdsec/pkg/apiclient" ) @@ -33,7 +32,6 @@ type AppsecAllowlist struct { ranges []rangeAllowlist lock sync.RWMutex logger *log.Entry - tomb *tomb.Tomb } func NewAppsecAllowlist(logger *log.Entry) *AppsecAllowlist { @@ -107,6 +105,7 @@ func (a *AppsecAllowlist) FetchAllowlists(ctx context.Context) error { func (a *AppsecAllowlist) updateAllowlists(ctx context.Context) { ticker := time.NewTicker(allowlistRefreshInterval) + defer ticker.Stop() for { select { @@ -114,19 +113,14 @@ func (a *AppsecAllowlist) updateAllowlists(ctx context.Context) { if err := a.FetchAllowlists(ctx); err != nil { a.logger.Errorf("failed to fetch allowlists: %s", err) } - case <-a.tomb.Dying(): - ticker.Stop() + case <-ctx.Done(): return } } } -func (a *AppsecAllowlist) StartRefresh(ctx 
context.Context, t *tomb.Tomb) { - a.tomb = t - a.tomb.Go(func() error { - a.updateAllowlists(ctx) - return nil - }) +func (a *AppsecAllowlist) StartRefresh(ctx context.Context) { + go a.updateAllowlists(ctx) } func (a *AppsecAllowlist) IsAllowlisted(sourceIP string) (bool, string) { diff --git a/pkg/csplugin/broker.go b/pkg/csplugin/broker.go index 3cb8d791a09..b42a050e9c1 100644 --- a/pkg/csplugin/broker.go +++ b/pkg/csplugin/broker.go @@ -17,7 +17,6 @@ import ( "github.com/google/uuid" plugin "github.com/hashicorp/go-plugin" log "github.com/sirupsen/logrus" - "gopkg.in/tomb.v2" "gopkg.in/yaml.v2" "github.com/crowdsecurity/go-cs-lib/csstring" @@ -40,17 +39,17 @@ const ( // It receives all the events from the main process and stacks them up // It is as well notified by the watcher when it needs to deliver events to plugins (based on time or count threshold) type PluginBroker struct { - PluginChannel chan models.ProfileAlert - alertsByPluginName map[string][]*models.Alert - profileConfigs []*csconfig.ProfileCfg - pluginConfigByName map[string]PluginConfig - pluginMap map[string]plugin.Plugin - notificationPluginByName map[string]protobufs.NotifierServer - watcher PluginWatcher - pluginKillMethods []func() - pluginProcConfig *csconfig.PluginCfg - pluginsTypesToDispatch map[string]struct{} - newBackoff backoffFactory + PluginChannel chan models.ProfileAlert + alertsByPluginName map[string][]*models.Alert + profileConfigs []*csconfig.ProfileCfg + pluginConfigByName map[string]PluginConfig + pluginMap map[string]plugin.Plugin + notificationPluginByName map[string]protobufs.NotifierServer + watcher PluginWatcher + pluginKillMethods []func() + pluginProcConfig *csconfig.PluginCfg + pluginsTypesToDispatch map[string]struct{} + newBackoff backoffFactory } // holder to determine where to dispatch config and how to format messages @@ -127,11 +126,22 @@ func (pb *PluginBroker) Kill() { } } -func (pb *PluginBroker) Run(pluginTomb *tomb.Tomb) { +func (pb *PluginBroker) 
popPluginAlerts(pluginName string) []*models.Alert { + pluginMutex.Lock() + log.Tracef("going to deliver %d alerts to plugin %s", len(pb.alertsByPluginName[pluginName]), pluginName) + tmpAlerts := pb.alertsByPluginName[pluginName] + pb.alertsByPluginName[pluginName] = make([]*models.Alert, 0) + pluginMutex.Unlock() + + return tmpAlerts +} + +func (pb *PluginBroker) Run(ctx context.Context) { // we get signaled via the channel when notifications need to be delivered to plugin (via the watcher) - ctx := context.TODO() + notifyCtx := ctx - pb.watcher.Start(&tomb.Tomb{}) + watcherCtx, cancelWatcher := context.WithCancel(ctx) + watcherDone := pb.watcher.Start(watcherCtx) for { select { @@ -139,47 +149,45 @@ func (pb *PluginBroker) Run(pluginTomb *tomb.Tomb) { pb.addProfileAlert(profileAlert) case pluginName := <-pb.watcher.PluginEvents: - // this can be run in goroutine, but then locks will be needed - pluginMutex.Lock() - log.Tracef("going to deliver %d alerts to plugin %s", len(pb.alertsByPluginName[pluginName]), pluginName) - tmpAlerts := pb.alertsByPluginName[pluginName] - pb.alertsByPluginName[pluginName] = make([]*models.Alert, 0) - pluginMutex.Unlock() - - go func() { + tmpAlerts := pb.popPluginAlerts(pluginName) + + go func(pluginName string, alerts []*models.Alert) { // Chunk alerts to respect group_threshold threshold := pb.pluginConfigByName[pluginName].GroupThreshold if threshold == 0 { threshold = 1 } - for _, chunk := range slicetools.Chunks(tmpAlerts, threshold) { - if err := pb.pushNotificationsToPlugin(ctx, pluginName, chunk); err != nil { + for _, chunk := range slicetools.Chunks(alerts, threshold) { + if err := pb.pushNotificationsToPlugin(notifyCtx, pluginName, chunk); err != nil { log.WithField("plugin:", pluginName).Error(err) } } - }() + }(pluginName, tmpAlerts) - case <-pluginTomb.Dying(): - log.Infof("pluginTomb dying") - pb.watcher.tomb.Kill(errors.New("Terminating")) + case <-ctx.Done(): + log.Infof("plugin context canceled") + cancelWatcher() 
+ flushCtx, cancelFlush := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second) for { select { - case <-pb.watcher.tomb.Dead(): + case <-watcherDone: + cancelFlush() log.Info("killing all plugins") pb.Kill() + return + case <-flushCtx.Done(): + cancelFlush() + log.Warn("plugin flush timeout reached, stopping pending notifications") + pb.Kill() + return case pluginName := <-pb.watcher.PluginEvents: - // this can be run in goroutine, but then locks will be needed - pluginMutex.Lock() - log.Tracef("going to deliver %d alerts to plugin %s", len(pb.alertsByPluginName[pluginName]), pluginName) - tmpAlerts := pb.alertsByPluginName[pluginName] - pb.alertsByPluginName[pluginName] = make([]*models.Alert, 0) - pluginMutex.Unlock() - - if err := pb.pushNotificationsToPlugin(ctx, pluginName, tmpAlerts); err != nil { + tmpAlerts := pb.popPluginAlerts(pluginName) + + if err := pb.pushNotificationsToPlugin(flushCtx, pluginName, tmpAlerts); err != nil { log.WithField("plugin:", pluginName).Error(err) } } diff --git a/pkg/csplugin/broker_test.go b/pkg/csplugin/broker_test.go index 07fb1e11c13..04604f658de 100644 --- a/pkg/csplugin/broker_test.go +++ b/pkg/csplugin/broker_test.go @@ -4,6 +4,7 @@ package csplugin import ( "bytes" + "context" "encoding/json" "io" "os" @@ -13,7 +14,6 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/tomb.v2" "gopkg.in/yaml.v3" "github.com/crowdsecurity/go-cs-lib/cstest" @@ -154,8 +154,9 @@ func (s *PluginSuite) TestBrokerNoThreshold() { pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) - tomb := tomb.Tomb{} - go pb.Run(&tomb) + brokerCtx, cancelBroker := context.WithCancel(ctx) + defer cancelBroker() + go pb.Run(brokerCtx) // send one item, it should be processed right now pb.PluginChannel <- models.ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}} @@ -206,8 +207,9 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_TimeFirst() { pb, err := 
s.InitBroker(ctx, nil) require.NoError(t, err) - tomb := tomb.Tomb{} - go pb.Run(&tomb) + brokerCtx, cancelBroker := context.WithCancel(ctx) + defer cancelBroker() + go pb.Run(brokerCtx) // send data pb.PluginChannel <- models.ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}} @@ -243,8 +245,9 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_CountFirst() { pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) - tomb := tomb.Tomb{} - go pb.Run(&tomb) + brokerCtx, cancelBroker := context.WithCancel(ctx) + defer cancelBroker() + go pb.Run(brokerCtx) // send data pb.PluginChannel <- models.ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}} @@ -284,8 +287,9 @@ func (s *PluginSuite) TestBrokerRunGroupThreshold() { pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) - tomb := tomb.Tomb{} - go pb.Run(&tomb) + brokerCtx, cancelBroker := context.WithCancel(ctx) + defer cancelBroker() + go pb.Run(brokerCtx) // send data pb.PluginChannel <- models.ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}} @@ -338,8 +342,9 @@ func (s *PluginSuite) TestBrokerRunTimeThreshold() { pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) - tomb := tomb.Tomb{} - go pb.Run(&tomb) + brokerCtx, cancelBroker := context.WithCancel(ctx) + defer cancelBroker() + go pb.Run(brokerCtx) // send data pb.PluginChannel <- models.ProfileAlert{ProfileID: uint(0), Alert: &models.Alert{}} @@ -369,8 +374,9 @@ func (s *PluginSuite) TestBrokerRunSimple() { pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) - tomb := tomb.Tomb{} - go pb.Run(&tomb) + brokerCtx, cancelBroker := context.WithCancel(ctx) + defer cancelBroker() + go pb.Run(brokerCtx) assert.NoFileExists(t, s.outFile) diff --git a/pkg/csplugin/broker_win_test.go b/pkg/csplugin/broker_win_test.go index 351ca2ac7d4..fc065eeb2b2 100644 --- a/pkg/csplugin/broker_win_test.go +++ b/pkg/csplugin/broker_win_test.go @@ -4,6 +4,7 @@ package csplugin import ( "bytes" + "context" "encoding/json" "io" "os" @@ -12,7 
+13,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/go-cs-lib/cstest" @@ -73,8 +73,9 @@ func (s *PluginSuite) TestBrokerRun() { pb, err := s.InitBroker(ctx, nil) require.NoError(t, err) - tomb := tomb.Tomb{} - go pb.Run(&tomb) + brokerCtx, cancelBroker := context.WithCancel(ctx) + defer cancelBroker() + go pb.Run(brokerCtx) assert.NoFileExists(t, s.outFile) defer os.Remove(s.outFile) diff --git a/pkg/csplugin/watcher.go b/pkg/csplugin/watcher.go index 6a5d31909d5..c1ed3db8ac8 100644 --- a/pkg/csplugin/watcher.go +++ b/pkg/csplugin/watcher.go @@ -1,11 +1,11 @@ package csplugin import ( + "context" "sync" "time" log "github.com/sirupsen/logrus" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/crowdsec/pkg/models" ) @@ -50,7 +50,8 @@ type PluginWatcher struct { AlertCountByPluginName alertCounterByPluginName PluginEvents chan string Inserts chan string - tomb *tomb.Tomb + wg sync.WaitGroup + done chan struct{} } var DefaultEmptyTicker = time.Second * 1 @@ -65,23 +66,33 @@ func (pw *PluginWatcher) Init(configs map[string]PluginConfig, alertsByPluginNam } } -func (pw *PluginWatcher) Start(tomb *tomb.Tomb) { - pw.tomb = tomb +func (pw *PluginWatcher) Start(ctx context.Context) <-chan struct{} { + pw.done = make(chan struct{}) + for name := range pw.PluginConfigByName { pname := name - pw.tomb.Go(func() error { - pw.watchPluginTicker(pname) - return nil - }) + pw.wg.Add(1) + go func() { + defer pw.wg.Done() + pw.watchPluginTicker(ctx, pname) + }() } - pw.tomb.Go(func() error { - pw.watchPluginAlertCounts() - return nil - }) + pw.wg.Add(1) + go func() { + defer pw.wg.Done() + pw.watchPluginAlertCounts(ctx) + }() + + go func() { + pw.wg.Wait() + close(pw.done) + }() + + return pw.done } -func (pw *PluginWatcher) watchPluginTicker(pluginName string) { +func (pw *PluginWatcher) watchPluginTicker(ctx context.Context, pluginName string) { cfg := pw.PluginConfigByName[pluginName] interval := 
cfg.GroupWait threshold := cfg.GroupThreshold @@ -142,15 +153,15 @@ func (pw *PluginWatcher) watchPluginTicker(pluginName string) { lastSend = now pw.AlertCountByPluginName.Set(pluginName, 0) pw.PluginEvents <- pluginName - case <-pw.tomb.Dying(): - // no lock here because we have the broker still listening even in dying state before killing us + case <-ctx.Done(): + // no lock here because the broker is still listening during shutdown. pw.PluginEvents <- pluginName return } } } -func (pw *PluginWatcher) watchPluginAlertCounts() { +func (pw *PluginWatcher) watchPluginAlertCounts(ctx context.Context) { for { select { case pluginName := <-pw.Inserts: @@ -160,7 +171,7 @@ func (pw *PluginWatcher) watchPluginAlertCounts() { curr = 0 } pw.AlertCountByPluginName.Set(pluginName, curr+1) - case <-pw.tomb.Dying(): + case <-ctx.Done(): return } } diff --git a/pkg/csplugin/watcher_test.go b/pkg/csplugin/watcher_test.go index 4a3a1bacea3..6f1efe59dd6 100644 --- a/pkg/csplugin/watcher_test.go +++ b/pkg/csplugin/watcher_test.go @@ -7,19 +7,26 @@ import ( "time" "github.com/stretchr/testify/require" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/go-cs-lib/cstest" "github.com/crowdsecurity/crowdsec/pkg/models" ) -func resetTestTomb(t *testing.T, testTomb *tomb.Tomb, pw *PluginWatcher) { - testTomb.Kill(nil) - <-pw.PluginEvents +func stopWatcher(t *testing.T, cancel context.CancelFunc, pw *PluginWatcher, done <-chan struct{}) { + cancel() - err := testTomb.Wait() - require.NoError(t, err) + select { + case <-pw.PluginEvents: + case <-time.After(time.Second): + require.FailNow(t, "timeout waiting for watcher flush event") + } + + select { + case <-done: + case <-time.After(time.Second): + require.FailNow(t, "timeout waiting for watcher shutdown") + } } func resetWatcherAlertCounter(pw *PluginWatcher) { @@ -54,23 +61,24 @@ func TestPluginWatcherInterval(t *testing.T) { pw := PluginWatcher{} alertsByPluginName := make(map[string][]*models.Alert) - testTomb := tomb.Tomb{} + watcherCtx, 
cancelWatcher := context.WithCancel(ctx) configs := map[string]PluginConfig{ "testPlugin": { GroupWait: time.Millisecond, }, } pw.Init(configs, alertsByPluginName) - pw.Start(&testTomb) + watcherDone := pw.Start(watcherCtx) ct, cancel := context.WithTimeout(ctx, time.Microsecond) defer cancel() err := listenChannelWithTimeout(ct, pw.PluginEvents) cstest.RequireErrorContains(t, err, "context deadline exceeded") - resetTestTomb(t, &testTomb, &pw) - testTomb = tomb.Tomb{} - pw.Start(&testTomb) + stopWatcher(t, cancelWatcher, &pw, watcherDone) + + watcherCtx, cancelWatcher = context.WithCancel(ctx) + watcherDone = pw.Start(watcherCtx) insertNAlertsToPlugin(&pw, 1, "testPlugin") ct, cancel = context.WithTimeout(ctx, time.Millisecond*5) @@ -78,7 +86,7 @@ func TestPluginWatcherInterval(t *testing.T) { err = listenChannelWithTimeout(ct, pw.PluginEvents) require.NoError(t, err) - resetTestTomb(t, &testTomb, &pw) + stopWatcher(t, cancelWatcher, &pw, watcherDone) // This is to avoid the int complaining } @@ -94,10 +102,10 @@ func TestPluginAlertCountWatcher(t *testing.T) { GroupThreshold: 5, }, } - testTomb := tomb.Tomb{} + watcherCtx, cancelWatcher := context.WithCancel(ctx) pw.Init(configs, alertsByPluginName) - pw.Start(&testTomb) + watcherDone := pw.Start(watcherCtx) // Channel won't contain any events since threshold is not crossed. 
ct, cancel := context.WithTimeout(ctx, time.Second) @@ -125,5 +133,5 @@ func TestPluginAlertCountWatcher(t *testing.T) { err = listenChannelWithTimeout(ct, pw.PluginEvents) require.NoError(t, err) - resetTestTomb(t, &testTomb, &pw) + stopWatcher(t, cancelWatcher, &pw, watcherDone) } diff --git a/pkg/longpollclient/client.go b/pkg/longpollclient/client.go index af4665ff1f0..3b538609558 100644 --- a/pkg/longpollclient/client.go +++ b/pkg/longpollclient/client.go @@ -8,17 +8,19 @@ import ( "io" "net/http" "net/url" + "sync" "time" "github.com/google/uuid" log "github.com/sirupsen/logrus" - "gopkg.in/tomb.v2" "github.com/crowdsecurity/crowdsec/pkg/logging" ) type LongPollClient struct { - t tomb.Tomb + mu sync.Mutex + cancelPoll context.CancelFunc + done chan struct{} c chan Event url url.URL logger *log.Entry @@ -89,8 +91,8 @@ func (c *LongPollClient) poll(ctx context.Context) error { defer resp.Body.Close() - requestId := resp.Header.Get("X-Amzn-Trace-Id") - logger = logger.WithField("request-id", requestId) + requestID := resp.Header.Get("X-Amzn-Trace-Id") + logger = logger.WithField("request-id", requestID) if resp.StatusCode != http.StatusOK { c.logger.Errorf("unexpected status code: %d", resp.StatusCode) if resp.StatusCode == http.StatusPaymentRequired { @@ -109,20 +111,15 @@ func (c *LongPollClient) poll(ctx context.Context) error { for { select { - case <-c.t.Dying(): - logger.Debugf("dying") - close(c.c) - return nil case <-ctx.Done(): - logger.Debugf("context canceled") - close(c.c) + logger.Debug("context canceled") return ctx.Err() default: var pollResp pollResponse err = decoder.Decode(&pollResp) if err != nil { if errors.Is(err, io.EOF) { - logger.Debugf("server closed connection") + logger.Debug("server closed connection") return nil } return fmt.Errorf("error decoding poll response: %v", err) @@ -132,7 +129,7 @@ func (c *LongPollClient) poll(ctx context.Context) error { if pollResp.ErrorMessage != "" { if pollResp.ErrorMessage == timeoutMessage { - 
logger.Debugf("got timeout message") + logger.Debug("got timeout message") return nil } return fmt.Errorf("longpoll API error message: %s", pollResp.ErrorMessage) @@ -141,8 +138,12 @@ func (c *LongPollClient) poll(ctx context.Context) error { if len(pollResp.Events) > 0 { logger.Debugf("got %d events", len(pollResp.Events)) for _, event := range pollResp.Events { - event.RequestId = requestId - c.c <- event + event.RequestId = requestID + select { + case <-ctx.Done(): + return ctx.Err() + case c.c <- event: + } if event.Timestamp > c.since { c.since = event.Timestamp } @@ -157,27 +158,21 @@ func (c *LongPollClient) poll(ctx context.Context) error { } func (c *LongPollClient) pollEvents(ctx context.Context) error { - initialBackoff := 1 * time.Second maxBackoff := 30 * time.Second currentBackoff := initialBackoff for { select { - case <-c.t.Dying(): - c.logger.Debug("dying") - return nil case <-ctx.Done(): c.logger.Debug("context canceled") - return ctx.Err() + return nil default: c.logger.Debug("Polling PAPI") err := c.poll(ctx) if err != nil { if errors.Is(err, errUnauthorized) { c.logger.Errorf("unauthorized, stopping polling") - c.t.Kill(err) - close(c.c) return err } if errors.Is(err, context.Canceled) { @@ -186,10 +181,8 @@ func (c *LongPollClient) pollEvents(ctx context.Context) error { } c.logger.Errorf("failed to poll: %s, retrying in %s", err, currentBackoff) select { - case <-c.t.Dying(): - return nil case <-ctx.Done(): - return ctx.Err() + return nil case <-time.After(currentBackoff): } @@ -205,17 +198,47 @@ func (c *LongPollClient) pollEvents(ctx context.Context) error { } func (c *LongPollClient) Start(ctx context.Context, since time.Time) chan Event { - c.logger.Infof("starting polling client") - c.c = make(chan Event) + _ = c.Stop() + + c.logger.Info("starting polling client") + + pollCtx, cancel := context.WithCancel(ctx) + out := make(chan Event) + done := make(chan struct{}) + + c.mu.Lock() + c.c = out c.since = since.Unix() * 1000 c.timeout = 
"45" - c.t = tomb.Tomb{} - c.t.Go(func() error { return c.pollEvents(ctx) }) - return c.c + c.cancelPoll = cancel + c.done = done + c.mu.Unlock() + + go func() { + defer close(done) + defer close(out) + _ = c.pollEvents(pollCtx) + }() + + return out } func (c *LongPollClient) Stop() error { - c.t.Kill(nil) + c.mu.Lock() + cancel := c.cancelPoll + done := c.done + c.cancelPoll = nil + c.done = nil + c.mu.Unlock() + + if cancel != nil { + cancel() + } + + if done != nil { + <-done + } + return nil } @@ -235,7 +258,7 @@ func (c *LongPollClient) PullOnce(ctx context.Context, since time.Time) ([]Event err = decoder.Decode(&pollResp) if err != nil { if errors.Is(err, io.EOF) { - c.logger.Debugf("server closed connection") + c.logger.Debug("server closed connection") break } log.Errorf("error decoding poll response: %v", err) @@ -246,7 +269,7 @@ func (c *LongPollClient) PullOnce(ctx context.Context, since time.Time) ([]Event if pollResp.ErrorMessage != "" { if pollResp.ErrorMessage == timeoutMessage { - c.logger.Debugf("got timeout message") + c.logger.Debug("got timeout message") break } log.Errorf("longpoll API error message: %s", pollResp.ErrorMessage) diff --git a/test/bats/00_wait_for.bats b/test/bats/00_wait_for.bats index b8530602cce..b44bc492c29 100644 --- a/test/bats/00_wait_for.bats +++ b/test/bats/00_wait_for.bats @@ -67,3 +67,12 @@ setup() { 2 EOT } + +@test "run a command with timeout (match, slow shutdown)" { + # even if the process doesn't stop quickly, a matched pattern must not become timeout=241 + rune -128 wait-for --timeout .2 --out "2" sh -c 'trap "" TERM; echo 1; echo 2; sleep 5' + assert_output - <<-EOT + 1 + 2 + EOT +} diff --git a/test/bin/wait-for b/test/bin/wait-for index 798e2a09b34..7bbcc97a2cc 100755 --- a/test/bin/wait-for +++ b/test/bin/wait-for @@ -15,12 +15,12 @@ DEFAULT_TIMEOUT = 30 # TODO: print unmatched patterns -async def terminate_group(p: asyncio.subprocess.Process): +async def terminate_group(p: asyncio.subprocess.Process, sig: 
signal.Signals = signal.SIGTERM): """ Terminate the process group (shell, crowdsec plugins) """ try: - os.killpg(os.getpgid(p.pid), signal.SIGTERM) + os.killpg(os.getpgid(p.pid), sig) except ProcessLookupError: pass @@ -60,6 +60,15 @@ async def monitor( out.write(line) if pattern and pattern.search(line): await terminate_group(process) + + # Ensure the process exits even if it ignores SIGTERM + async def _force_kill(): + await asyncio.sleep(2) + if process.returncode is None: + await terminate_group(process, signal.SIGKILL) + + asyncio.create_task(_force_kill()) + # this is nasty. # if we timeout, we want to return a different exit code # in case of a match, so that the caller can tell @@ -97,8 +106,13 @@ async def monitor( if status is None: status = process.returncode except asyncio.TimeoutError: + if status is None: + status = 241 await terminate_group(process) - status = 241 + try: + await asyncio.wait_for(process.wait(), timeout=0.2) + except asyncio.TimeoutError: + await terminate_group(process, signal.SIGKILL) # Return the same exit code, stdout and stderr as the spawned process return status or 0 diff --git a/test/lib/init/crowdsec-daemon b/test/lib/init/crowdsec-daemon index ba8e98992db..803187567ed 100755 --- a/test/lib/init/crowdsec-daemon +++ b/test/lib/init/crowdsec-daemon @@ -53,9 +53,18 @@ stop() { if [[ -n "${PGID}" ]]; then kill -- "-${PGID}" - while pgrep -g "${PGID}" >/dev/null; do + # wait up to 5s for graceful shutdown, then escalate to SIGKILL + for _ in $(seq 100); do + pgrep -g "${PGID}" >/dev/null || break sleep .05 done + + if pgrep -g "${PGID}" >/dev/null 2>&1; then + kill -9 -- "-${PGID}" 2>/dev/null || true + while pgrep -g "${PGID}" >/dev/null 2>&1; do + sleep .05 + done + fi fi rm -f -- "${DAEMON_PID}"