From 20fde29f077299e8256fa85aa4dadfabee5cebc0 Mon Sep 17 00:00:00 2001
From: Nikola Jokic
Date: Tue, 17 Feb 2026 15:22:09 +0100
Subject: [PATCH] wip

---
 cmd/ghalistener/app/app.go                   |  73 +-
 cmd/ghalistener/config/config_client_test.go |  86 +-
 cmd/ghalistener/listener/listener.go         | 459 ---------
 cmd/ghalistener/listener/listener_test.go    | 970 ------------------
 cmd/ghalistener/listener/metrics_test.go     | 205 ----
 cmd/ghalistener/listener/mocks/client.go     | 190 ----
 cmd/ghalistener/listener/mocks/handler.go    |  68 --
 cmd/ghalistener/metrics/metrics.go           |  43 +-
 .../metrics/metrics_integration_test.go      |  12 +-
 cmd/ghalistener/metrics/metrics_test.go      |  16 +-
 cmd/ghalistener/worker/worker.go             |  86 +-
 cmd/ghalistener/worker/worker_test.go        | 235 +++--
 go.mod                                       |   2 +-
 go.sum                                       |   2 +
 14 files changed, 303 insertions(+), 2144 deletions(-)
 delete mode 100644 cmd/ghalistener/listener/listener.go
 delete mode 100644 cmd/ghalistener/listener/listener_test.go
 delete mode 100644 cmd/ghalistener/listener/metrics_test.go
 delete mode 100644 cmd/ghalistener/listener/mocks/client.go
 delete mode 100644 cmd/ghalistener/listener/mocks/handler.go

diff --git a/cmd/ghalistener/app/app.go b/cmd/ghalistener/app/app.go
index f0318bb9..fd5d0d72 100644
--- a/cmd/ghalistener/app/app.go
+++ b/cmd/ghalistener/app/app.go
@@ -7,11 +7,11 @@ import (
 	"os"
 
 	"github.com/actions/actions-runner-controller/cmd/ghalistener/config"
-	"github.com/actions/actions-runner-controller/cmd/ghalistener/listener"
 	"github.com/actions/actions-runner-controller/cmd/ghalistener/metrics"
 	"github.com/actions/actions-runner-controller/cmd/ghalistener/worker"
 	"github.com/actions/actions-runner-controller/github/actions"
 	"github.com/actions/scaleset"
+	"github.com/actions/scaleset/listener"
 	"github.com/google/uuid"
 	"golang.org/x/sync/errgroup"
 )
@@ -29,17 +29,6 @@ type App struct {
 	metrics metrics.ServerExporter
 }
 
-//go:generate mockery --name Listener --output ./mocks --outpkg mocks --case underscore
-type Listener interface {
-	Listen(ctx context.Context, handler listener.Handler) error
-}
-
-//go:generate mockery --name Worker --output ./mocks --outpkg mocks --case underscore
-type Worker interface {
-	HandleJobStarted(ctx context.Context, jobInfo *actions.JobStarted) error
-	HandleDesiredRunnerCount(ctx context.Context, count int, jobsCompleted int) (int, error)
-}
-
 func New(config config.Config) (*App, error) {
 	if err := config.Validate(); err != nil {
 		return nil, fmt.Errorf("failed to validate config: %w", err)
@@ -102,14 +91,28 @@ func (app *App) Run(ctx context.Context) error {
 	}
 	defer sessionClient.Close(context.Background())
 
-	listener, err := listener.New(listener.Config{
-		Client:     sessionClient,
-		ScaleSetID: app.config.RunnerScaleSetID,
-		MinRunners: app.config.MinRunners,
-		MaxRunners: app.config.MaxRunners,
-		Logger:     *app.logger.With("component", "listener"),
-		Metrics:    app.metrics,
-	})
+	hasMetrics := app.metrics != nil
+
+	var listenerOptions []listener.Option
+	if hasMetrics {
+		listenerOptions = append(
+			listenerOptions,
+			listener.WithMetricsRecorder(
+				&metricsRecorder{metrics: app.metrics},
+			),
+		)
+		app.metrics.PublishStatic(app.config.MinRunners, app.config.MaxRunners)
+	}
+
+	listener, err := listener.New(
+		sessionClient,
+		listener.Config{
+			ScaleSetID: app.config.RunnerScaleSetID,
+			MaxRunners: app.config.MaxRunners,
+			Logger:     app.logger.With("component", "listener"),
+		},
+		listenerOptions...,
+	)
 	if err != nil {
 		return fmt.Errorf("failed to create new listener: %w", err)
 	}
@@ -132,12 +135,12 @@ func (app *App) Run(ctx context.Context) error {
 	g.Go(func() error {
 		app.logger.Info("Starting listener")
-		listnerErr := listener.Listen(ctx, worker)
+		listnerErr := listener.Run(ctx, worker)
 		cancelMetrics(fmt.Errorf("Listener exited: %w", listnerErr))
 		return listnerErr
 	})
-	if app.metrics != nil {
+	if hasMetrics {
 		g.Go(func() error {
 			app.logger.Info("Starting metrics server")
 			return app.metrics.ListenAndServe(metricsCtx)
 		})
 	}
@@ -146,3 +149,29 @@ func (app *App) Run(ctx context.Context) error {
 
 	return g.Wait()
 }
+
+var _ listener.MetricsRecorder = (*metricsRecorder)(nil)
+
+type metricsRecorder struct {
+	metrics metrics.Publisher // The publisher used to publish metrics.
+}
+
+// RecordDesiredRunners implements [listener.MetricsRecorder].
+func (m *metricsRecorder) RecordDesiredRunners(count int) {
+	m.metrics.PublishDesiredRunners(count)
+}
+
+// RecordJobCompleted implements [listener.MetricsRecorder].
+func (m *metricsRecorder) RecordJobCompleted(msg *scaleset.JobCompleted) {
+	m.metrics.PublishJobCompleted(msg)
+}
+
+// RecordJobStarted implements [listener.MetricsRecorder].
+func (m *metricsRecorder) RecordJobStarted(msg *scaleset.JobStarted) {
+	m.metrics.PublishJobStarted(msg)
+}
+
+// RecordStatistics implements [listener.MetricsRecorder].
+func (m *metricsRecorder) RecordStatistics(statistics *scaleset.RunnerScaleSetStatistic) {
+	m.metrics.PublishStatistics(statistics)
+}
diff --git a/cmd/ghalistener/config/config_client_test.go b/cmd/ghalistener/config/config_client_test.go
index 4fc37341..d7332684 100644
--- a/cmd/ghalistener/config/config_client_test.go
+++ b/cmd/ghalistener/config/config_client_test.go
@@ -3,21 +3,23 @@ package config_test
 import (
 	"context"
 	"crypto/tls"
+	"encoding/json"
+	"log/slog"
 	"net/http"
-	"net/http/httptest"
 	"os"
 	"path/filepath"
 	"testing"
 
 	"github.com/actions/actions-runner-controller/apis/actions.github.com/v1alpha1/appconfig"
 	"github.com/actions/actions-runner-controller/cmd/ghalistener/config"
-	"github.com/actions/actions-runner-controller/github/actions"
 	"github.com/actions/actions-runner-controller/github/actions/testserver"
-	"github.com/go-logr/logr"
+	"github.com/actions/scaleset"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
 
+var discardLogger = slog.New(slog.DiscardHandler)
+
 func TestCustomerServerRootCA(t *testing.T) {
 	ctx := context.Background()
 	certsFolder := filepath.Join(
@@ -59,7 +61,7 @@ func TestCustomerServerRootCA(t *testing.T) {
 		},
 	}
 
-	client, err := config.ActionsClient(logr.Discard())
+	client, err := config.ActionsClient(*discardLogger)
 	require.NoError(t, err)
 	_, err = client.GetRunnerScaleSet(ctx, 1, "test")
 	require.NoError(t, err)
@@ -67,18 +69,19 @@
 }
 
 func TestProxySettings(t *testing.T) {
+	assertHasProxy := func(t *testing.T, debugInfo string, want bool) {
+		type debugInfoContent struct {
+			HasProxy bool `json:"has_proxy"`
+		}
+		var got debugInfoContent
+		err := json.Unmarshal([]byte(debugInfo), &got)
+		require.NoError(t, err)
+		assert.Equal(t, want, got.HasProxy)
+	}
+
 	t.Run("http", func(t *testing.T) {
-		wentThroughProxy := false
-
-		proxy := httptest.NewServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
-			wentThroughProxy = true
-		}))
-		t.Cleanup(func() {
-			proxy.Close()
-		})
-
 		prevProxy := os.Getenv("http_proxy")
-		os.Setenv("http_proxy", proxy.URL)
+		os.Setenv("http_proxy", "http://proxy:8080")
 		defer os.Setenv("http_proxy", prevProxy)
 
 		config := config.Config{
@@ -88,29 +91,15 @@ func TestProxySettings(t *testing.T) {
 			},
 		}
 
-		client, err := config.ActionsClient(logr.Discard())
+		client, err := config.ActionsClient(*discardLogger)
 		require.NoError(t, err)
 
-		req, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
-		require.NoError(t, err)
-		_, err = client.Do(req)
-		require.NoError(t, err)
-
-		assert.True(t, wentThroughProxy)
+		assertHasProxy(t, client.DebugInfo(), true)
 	})
 
 	t.Run("https", func(t *testing.T) {
-		wentThroughProxy := false
-
-		proxy := httptest.NewServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
-			wentThroughProxy = true
-		}))
-		t.Cleanup(func() {
-			proxy.Close()
-		})
-
 		prevProxy := os.Getenv("https_proxy")
-		os.Setenv("https_proxy", proxy.URL)
+		os.Setenv("https_proxy", "https://proxy:443")
 		defer os.Setenv("https_proxy", prevProxy)
 
 		config := config.Config{
@@ -120,32 +109,16 @@ func TestProxySettings(t *testing.T) {
 			},
 		}
 
-		client, err := config.ActionsClient(logr.Discard(), actions.WithRetryMax(0))
+		client, err := config.ActionsClient(
+			*discardLogger,
+			scaleset.WithRetryMax(0),
+		)
 		require.NoError(t, err)
 
-		req, err := http.NewRequest(http.MethodGet, "https://example.com", nil)
-		require.NoError(t, err)
-
-		_, err = client.Do(req)
-		// proxy doesn't support https
-		assert.Error(t, err)
-		assert.True(t, wentThroughProxy)
+		assertHasProxy(t, client.DebugInfo(), true)
 	})
 
 	t.Run("no_proxy", func(t *testing.T) {
-		wentThroughProxy := false
-
-		proxy := httptest.NewServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
-			wentThroughProxy = true
-		}))
-		t.Cleanup(func() {
-			proxy.Close()
-		})
-
-		prevProxy := os.Getenv("http_proxy")
-		os.Setenv("http_proxy", proxy.URL)
-		defer os.Setenv("http_proxy", prevProxy)
-
 		prevNoProxy := os.Getenv("no_proxy")
 		os.Setenv("no_proxy", "example.com")
 		defer os.Setenv("no_proxy", prevNoProxy)
@@ -157,14 +130,9 @@ func TestProxySettings(t *testing.T) {
 			},
 		}
 
-		client, err := config.ActionsClient(logr.Discard())
+		client, err := config.ActionsClient(*discardLogger)
 		require.NoError(t, err)
 
-		req, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
-		require.NoError(t, err)
-
-		_, err = client.Do(req)
-		require.NoError(t, err)
-		assert.False(t, wentThroughProxy)
+		assertHasProxy(t, client.DebugInfo(), false)
 	})
 }
diff --git a/cmd/ghalistener/listener/listener.go b/cmd/ghalistener/listener/listener.go
deleted file mode 100644
index c530fe8c..00000000
--- a/cmd/ghalistener/listener/listener.go
+++ /dev/null
@@ -1,459 +0,0 @@
-package listener
-
-import (
-	"context"
-	"encoding/json"
-	"errors"
-	"fmt"
-	"log/slog"
-	"net/http"
-	"os"
-	"time"
-
-	"github.com/actions/actions-runner-controller/cmd/ghalistener/metrics"
-	"github.com/actions/actions-runner-controller/github/actions"
-	"github.com/actions/scaleset"
-	"github.com/go-logr/logr"
-	"github.com/google/uuid"
-)
-
-const (
-	sessionCreationMaxRetries = 10
-)
-
-// message types
-const (
-	messageTypeJobAvailable = "JobAvailable"
-	messageTypeJobAssigned  = "JobAssigned"
-	messageTypeJobStarted   = "JobStarted"
-	messageTypeJobCompleted = "JobCompleted"
-)
-
-// Client defines the interface for communicating with the scaleset API.
-// In most cases, it should be scaleset.Client from the scaleset package.
-// This interface is defined to allow for easier testing and mocking, as well
-// as allowing wrappers around the scaleset client if needed.
-type Client interface { - GetMessage(ctx context.Context, lastMessageID, maxCapacity int) (*scaleset.RunnerScaleSetMessage, error) - DeleteMessage(ctx context.Context, messageID int) error - Session() scaleset.RunnerScaleSetSession -} - -type Config struct { - Client Client - ScaleSetID int - MinRunners int - MaxRunners int - Logger slog.Logger - Metrics metrics.Publisher -} - -func (c *Config) Validate() error { - if c.Client == nil { - return errors.New("client is required") - } - if c.ScaleSetID == 0 { - return errors.New("scaleSetID is required") - } - if c.MinRunners < 0 { - return errors.New("minRunners must be greater than or equal to 0") - } - if c.MaxRunners < 0 { - return errors.New("maxRunners must be greater than or equal to 0") - } - if c.MaxRunners > 0 && c.MinRunners > c.MaxRunners { - return errors.New("minRunners must be less than or equal to maxRunners") - } - return nil -} - -// The Listener's role is to manage all interactions with the actions service. -// It receives messages and processes them using the given handler. -type Listener struct { - // configured fields - scaleSetID int // The ID of the scale set associated with the listener. - client Client // The client used to interact with the scale set. - metrics metrics.Publisher // The publisher used to publish metrics. - - // internal fields - logger logr.Logger // The logger used for logging. - hostname string // The hostname of the listener. - - // updated fields - lastMessageID int64 // The ID of the last processed message. - maxCapacity int // The maximum number of runners that can be created. - session *actions.RunnerScaleSetSession // The session for managing the runner scale set. -} - -func New(config Config) (*Listener, error) { - if err := config.Validate(); err != nil { - return nil, fmt.Errorf("invalid config: %w", err) - } - - listener := &Listener{ - scaleSetID: config.ScaleSetID, - client: config.Client, - logger: config.Logger, - metrics: metrics.Discard, - maxCapacity: config.MaxRunners, - } - - if config.Metrics != nil { - listener.metrics = config.Metrics - } - - listener.metrics.PublishStatic(config.MinRunners, config.MaxRunners) - - hostname, err := os.Hostname() - if err != nil { - hostname = uuid.NewString() - listener.logger.Info("Failed to get hostname, fallback to uuid", "uuid", hostname, "error", err) - } - listener.hostname = hostname - - return listener, nil -} - -//go:generate mockery --name Handler --output ./mocks --outpkg mocks --case underscore -type Handler interface { - HandleJobStarted(ctx context.Context, jobInfo *actions.JobStarted) error - HandleDesiredRunnerCount(ctx context.Context, count, jobsCompleted int) (int, error) -} - -// Listen listens for incoming messages and handles them using the provided handler. -// It continuously listens for messages until the context is cancelled. -// The initial message contains the current statistics and acquirable jobs, if any. -// The handler is responsible for handling the initial message and subsequent messages. -// If an error occurs during any step, Listen returns an error. 
-func (l *Listener) Listen(ctx context.Context, handler Handler) error { - if err := l.createSession(ctx); err != nil { - return fmt.Errorf("createSession failed: %w", err) - } - - defer func() { - if err := l.deleteMessageSession(); err != nil { - l.logger.Error(err, "failed to delete message session") - } - }() - - initialMessage := &actions.RunnerScaleSetMessage{ - MessageId: 0, - MessageType: "RunnerScaleSetJobMessages", - Statistics: l.session.Statistics, - Body: "", - } - - if l.session.Statistics == nil { - return fmt.Errorf("session statistics is nil") - } - l.metrics.PublishStatistics(initialMessage.Statistics) - - desiredRunners, err := handler.HandleDesiredRunnerCount(ctx, initialMessage.Statistics.TotalAssignedJobs, 0) - if err != nil { - return fmt.Errorf("handling initial message failed: %w", err) - } - l.metrics.PublishDesiredRunners(desiredRunners) - - for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - - msg, err := l.getMessage(ctx) - if err != nil { - return fmt.Errorf("failed to get message: %w", err) - } - - if msg == nil { - _, err := handler.HandleDesiredRunnerCount(ctx, 0, 0) - if err != nil { - return fmt.Errorf("handling nil message failed: %w", err) - } - - continue - } - - // Remove cancellation from the context to avoid cancelling the message handling. - if err := l.handleMessage(context.WithoutCancel(ctx), handler, msg); err != nil { - return fmt.Errorf("failed to handle message: %w", err) - } - } -} - -func (l *Listener) handleMessage(ctx context.Context, handler Handler, msg *actions.RunnerScaleSetMessage) error { - parsedMsg, err := l.parseMessage(ctx, msg) - if err != nil { - return fmt.Errorf("failed to parse message: %w", err) - } - l.metrics.PublishStatistics(parsedMsg.statistics) - - if len(parsedMsg.jobsAvailable) > 0 { - acquiredJobIDs, err := l.acquireAvailableJobs(ctx, parsedMsg.jobsAvailable) - if err != nil { - return fmt.Errorf("failed to acquire jobs: %w", err) - } - - l.logger.Info("Jobs are acquired", "count", len(acquiredJobIDs), "requestIds", fmt.Sprint(acquiredJobIDs)) - } - - for _, jobCompleted := range parsedMsg.jobsCompleted { - l.metrics.PublishJobCompleted(jobCompleted) - } - - l.lastMessageID = msg.MessageId - - if err := l.deleteLastMessage(ctx); err != nil { - return fmt.Errorf("failed to delete message: %w", err) - } - - for _, jobStarted := range parsedMsg.jobsStarted { - if err := handler.HandleJobStarted(ctx, jobStarted); err != nil { - return fmt.Errorf("failed to handle job started: %w", err) - } - l.metrics.PublishJobStarted(jobStarted) - } - - desiredRunners, err := handler.HandleDesiredRunnerCount(ctx, parsedMsg.statistics.TotalAssignedJobs, len(parsedMsg.jobsCompleted)) - if err != nil { - return fmt.Errorf("failed to handle desired runner count: %w", err) - } - l.metrics.PublishDesiredRunners(desiredRunners) - return nil -} - -func (l *Listener) createSession(ctx context.Context) error { - var session *actions.RunnerScaleSetSession - var retries int - - for { - var err error - session, err = l.client.CreateMessageSession(ctx, l.scaleSetID, l.hostname) - if err == nil { - break - } - - clientErr := &actions.HttpClientSideError{} - if !errors.As(err, &clientErr) { - return fmt.Errorf("failed to create session: %w", err) - } - - if clientErr.Code != http.StatusConflict { - return fmt.Errorf("failed to create session: %w", err) - } - - retries++ - if retries >= sessionCreationMaxRetries { - return fmt.Errorf("failed to create session after %d retries: %w", retries, err) - } - - l.logger.Info("Unable to 
create message session. Will try again in 30 seconds", "error", err.Error()) - - select { - case <-ctx.Done(): - return fmt.Errorf("context cancelled: %w", ctx.Err()) - case <-time.After(30 * time.Second): - } - } - - statistics, err := json.Marshal(session.Statistics) - if err != nil { - return fmt.Errorf("failed to marshal statistics: %w", err) - } - l.logger.Info("Current runner scale set statistics.", "statistics", string(statistics)) - - l.session = session - - return nil -} - -func (l *Listener) getMessage(ctx context.Context) (*actions.RunnerScaleSetMessage, error) { - l.logger.Info("Getting next message", "lastMessageID", l.lastMessageID) - msg, err := l.client.GetMessage(ctx, l.session.MessageQueueUrl, l.session.MessageQueueAccessToken, l.lastMessageID, l.maxCapacity) - if err == nil { // if NO error - return msg, nil - } - - expiredError := &actions.MessageQueueTokenExpiredError{} - if !errors.As(err, &expiredError) { - return nil, fmt.Errorf("failed to get next message: %w", err) - } - - if err := l.refreshSession(ctx); err != nil { - return nil, err - } - - l.logger.Info("Getting next message", "lastMessageID", l.lastMessageID) - - msg, err = l.client.GetMessage(ctx, l.session.MessageQueueUrl, l.session.MessageQueueAccessToken, l.lastMessageID, l.maxCapacity) - if err != nil { // if NO error - return nil, fmt.Errorf("failed to get next message after message session refresh: %w", err) - } - - return msg, nil -} - -func (l *Listener) deleteLastMessage(ctx context.Context) error { - l.logger.Info("Deleting last message", "lastMessageID", l.lastMessageID) - err := l.client.DeleteMessage(ctx, l.session.MessageQueueUrl, l.session.MessageQueueAccessToken, l.lastMessageID) - if err == nil { // if NO error - return nil - } - - expiredError := &actions.MessageQueueTokenExpiredError{} - if !errors.As(err, &expiredError) { - return fmt.Errorf("failed to delete last message: %w", err) - } - - if err := l.refreshSession(ctx); err != nil { - return err - } - - err = l.client.DeleteMessage(ctx, l.session.MessageQueueUrl, l.session.MessageQueueAccessToken, l.lastMessageID) - if err != nil { - return fmt.Errorf("failed to delete last message after message session refresh: %w", err) - } - - return nil -} - -type parsedMessage struct { - statistics *actions.RunnerScaleSetStatistic - jobsStarted []*actions.JobStarted - jobsAvailable []*actions.JobAvailable - jobsCompleted []*actions.JobCompleted -} - -func (l *Listener) parseMessage(ctx context.Context, msg *actions.RunnerScaleSetMessage) (*parsedMessage, error) { - if msg.MessageType != "RunnerScaleSetJobMessages" { - l.logger.Info("Skipping message", "messageType", msg.MessageType) - return nil, fmt.Errorf("invalid message type: %s", msg.MessageType) - } - - l.logger.Info("Processing message", "messageId", msg.MessageId, "messageType", msg.MessageType) - if msg.Statistics == nil { - return nil, fmt.Errorf("invalid message: statistics is nil") - } - - l.logger.Info("New runner scale set statistics.", "statistics", msg.Statistics) - - var batchedMessages []json.RawMessage - if len(msg.Body) > 0 { - if err := json.Unmarshal([]byte(msg.Body), &batchedMessages); err != nil { - return nil, fmt.Errorf("failed to unmarshal batched messages: %w", err) - } - } - - parsedMsg := &parsedMessage{ - statistics: msg.Statistics, - } - - for _, msg := range batchedMessages { - var messageType actions.JobMessageType - if err := json.Unmarshal(msg, &messageType); err != nil { - return nil, fmt.Errorf("failed to decode job message type: %w", err) - } - - switch 
messageType.MessageType { - case messageTypeJobAvailable: - var jobAvailable actions.JobAvailable - if err := json.Unmarshal(msg, &jobAvailable); err != nil { - return nil, fmt.Errorf("failed to decode job available: %w", err) - } - - l.logger.Info("Job available message received", "jobId", jobAvailable.JobID) - parsedMsg.jobsAvailable = append(parsedMsg.jobsAvailable, &jobAvailable) - - case messageTypeJobAssigned: - var jobAssigned actions.JobAssigned - if err := json.Unmarshal(msg, &jobAssigned); err != nil { - return nil, fmt.Errorf("failed to decode job assigned: %w", err) - } - - l.logger.Info("Job assigned message received", "jobId", jobAssigned.JobID) - - case messageTypeJobStarted: - var jobStarted actions.JobStarted - if err := json.Unmarshal(msg, &jobStarted); err != nil { - return nil, fmt.Errorf("could not decode job started message. %w", err) - } - l.logger.Info("Job started message received.", "JobID", jobStarted.JobID, "RunnerId", jobStarted.RunnerID) - parsedMsg.jobsStarted = append(parsedMsg.jobsStarted, &jobStarted) - - case messageTypeJobCompleted: - var jobCompleted actions.JobCompleted - if err := json.Unmarshal(msg, &jobCompleted); err != nil { - return nil, fmt.Errorf("failed to decode job completed: %w", err) - } - - l.logger.Info( - "Job completed message received.", - "JobID", jobCompleted.JobID, - "Result", jobCompleted.Result, - "RunnerId", jobCompleted.RunnerId, - "RunnerName", jobCompleted.RunnerName, - ) - parsedMsg.jobsCompleted = append(parsedMsg.jobsCompleted, &jobCompleted) - - default: - l.logger.Info("unknown job message type.", "messageType", messageType.MessageType) - } - } - - return parsedMsg, nil -} - -func (l *Listener) acquireAvailableJobs(ctx context.Context, jobsAvailable []*actions.JobAvailable) ([]int64, error) { - ids := make([]int64, 0, len(jobsAvailable)) - for _, job := range jobsAvailable { - ids = append(ids, job.RunnerRequestID) - } - - l.logger.Info("Acquiring jobs", "count", len(ids), "requestIds", fmt.Sprint(ids)) - - idsAcquired, err := l.client.AcquireJobs(ctx, l.scaleSetID, l.session.MessageQueueAccessToken, ids) - if err == nil { // if NO errors - return idsAcquired, nil - } - - expiredError := &actions.MessageQueueTokenExpiredError{} - if !errors.As(err, &expiredError) { - return nil, fmt.Errorf("failed to acquire jobs: %w", err) - } - - if err := l.refreshSession(ctx); err != nil { - return nil, err - } - - idsAcquired, err = l.client.AcquireJobs(ctx, l.scaleSetID, l.session.MessageQueueAccessToken, ids) - if err != nil { - return nil, fmt.Errorf("failed to acquire jobs after session refresh: %w", err) - } - - return idsAcquired, nil -} - -func (l *Listener) refreshSession(ctx context.Context) error { - l.logger.Info("Message queue token is expired during GetNextMessage, refreshing...") - session, err := l.client.RefreshMessageSession(ctx, l.session.RunnerScaleSet.Id, l.session.SessionId) - if err != nil { - return fmt.Errorf("refresh message session failed. 
%w", err) - } - - l.session = session - return nil -} - -func (l *Listener) deleteMessageSession() error { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - l.logger.Info("Deleting message session") - - if err := l.client.DeleteMessageSession(ctx, l.session.RunnerScaleSet.Id, l.session.SessionId); err != nil { - return fmt.Errorf("failed to delete message session: %w", err) - } - - return nil -} diff --git a/cmd/ghalistener/listener/listener_test.go b/cmd/ghalistener/listener/listener_test.go deleted file mode 100644 index af3b256d..00000000 --- a/cmd/ghalistener/listener/listener_test.go +++ /dev/null @@ -1,970 +0,0 @@ -package listener - -import ( - "context" - "encoding/json" - "errors" - "net/http" - "testing" - "time" - - listenermocks "github.com/actions/actions-runner-controller/cmd/ghalistener/listener/mocks" - "github.com/actions/actions-runner-controller/cmd/ghalistener/metrics" - "github.com/actions/actions-runner-controller/github/actions" - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" -) - -func TestNew(t *testing.T) { - t.Parallel() - t.Run("InvalidConfig", func(t *testing.T) { - t.Parallel() - var config Config - _, err := New(config) - assert.NotNil(t, err) - }) - - t.Run("ValidConfig", func(t *testing.T) { - t.Parallel() - config := Config{ - Client: listenermocks.NewClient(t), - ScaleSetID: 1, - Metrics: metrics.Discard, - } - l, err := New(config) - assert.Nil(t, err) - assert.NotNil(t, l) - }) -} - -func TestListener_createSession(t *testing.T) { - t.Parallel() - t.Run("FailOnce", func(t *testing.T) { - t.Parallel() - ctx := context.Background() - - config := Config{ - ScaleSetID: 1, - Metrics: metrics.Discard, - } - - client := listenermocks.NewClient(t) - client.On("CreateMessageSession", ctx, mock.Anything, mock.Anything).Return(nil, assert.AnError).Once() - config.Client = client - - l, err := New(config) - require.Nil(t, err) - - err = l.createSession(ctx) - assert.NotNil(t, err) - }) - - t.Run("FailContext", func(t *testing.T) { - t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - config := Config{ - ScaleSetID: 1, - Metrics: metrics.Discard, - } - - client := listenermocks.NewClient(t) - client.On("CreateMessageSession", ctx, mock.Anything, mock.Anything).Return(nil, - &actions.HttpClientSideError{Code: http.StatusConflict}).Once() - config.Client = client - - l, err := New(config) - require.Nil(t, err) - - err = l.createSession(ctx) - assert.True(t, errors.Is(err, context.DeadlineExceeded)) - }) - - t.Run("SetsSession", func(t *testing.T) { - t.Parallel() - config := Config{ - ScaleSetID: 1, - Metrics: metrics.Discard, - } - - client := listenermocks.NewClient(t) - - uuid := uuid.New() - session := &actions.RunnerScaleSetSession{ - SessionId: &uuid, - OwnerName: "example", - RunnerScaleSet: &actions.RunnerScaleSet{}, - MessageQueueUrl: "https://example.com", - MessageQueueAccessToken: "1234567890", - Statistics: nil, - } - client.On("CreateMessageSession", mock.Anything, mock.Anything, mock.Anything).Return(session, nil).Once() - config.Client = client - - l, err := New(config) - require.Nil(t, err) - - err = l.createSession(context.Background()) - assert.Nil(t, err) - assert.Equal(t, session, l.session) - }) -} - -func TestListener_getMessage(t *testing.T) { - t.Parallel() - - t.Run("ReceivesMessage", func(t *testing.T) { - t.Parallel() - - ctx := 
context.Background() - config := Config{ - ScaleSetID: 1, - Metrics: metrics.Discard, - MaxRunners: 10, - } - - client := listenermocks.NewClient(t) - want := &actions.RunnerScaleSetMessage{ - MessageId: 1, - } - client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything, 10).Return(want, nil).Once() - config.Client = client - - l, err := New(config) - require.Nil(t, err) - l.session = &actions.RunnerScaleSetSession{} - - got, err := l.getMessage(ctx) - assert.Nil(t, err) - assert.Equal(t, want, got) - }) - - t.Run("NotExpiredError", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - config := Config{ - ScaleSetID: 1, - Metrics: metrics.Discard, - MaxRunners: 10, - } - - client := listenermocks.NewClient(t) - client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything, 10).Return(nil, &actions.HttpClientSideError{Code: http.StatusNotFound}).Once() - config.Client = client - - l, err := New(config) - require.Nil(t, err) - - l.session = &actions.RunnerScaleSetSession{} - - _, err = l.getMessage(ctx) - assert.NotNil(t, err) - }) - - t.Run("RefreshAndSucceeds", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - config := Config{ - ScaleSetID: 1, - Metrics: metrics.Discard, - MaxRunners: 10, - } - - client := listenermocks.NewClient(t) - - uuid := uuid.New() - session := &actions.RunnerScaleSetSession{ - SessionId: &uuid, - OwnerName: "example", - RunnerScaleSet: &actions.RunnerScaleSet{}, - MessageQueueUrl: "https://example.com", - MessageQueueAccessToken: "1234567890", - Statistics: nil, - } - client.On("RefreshMessageSession", ctx, mock.Anything, mock.Anything).Return(session, nil).Once() - - client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything, 10).Return(nil, &actions.MessageQueueTokenExpiredError{}).Once() - - want := &actions.RunnerScaleSetMessage{ - MessageId: 1, - } - client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything, 10).Return(want, nil).Once() - - config.Client = client - - l, err := New(config) - require.Nil(t, err) - - l.session = &actions.RunnerScaleSetSession{ - SessionId: &uuid, - RunnerScaleSet: &actions.RunnerScaleSet{}, - } - - got, err := l.getMessage(ctx) - assert.Nil(t, err) - assert.Equal(t, want, got) - }) - - t.Run("RefreshAndFails", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - config := Config{ - ScaleSetID: 1, - Metrics: metrics.Discard, - MaxRunners: 10, - } - - client := listenermocks.NewClient(t) - - uuid := uuid.New() - session := &actions.RunnerScaleSetSession{ - SessionId: &uuid, - OwnerName: "example", - RunnerScaleSet: &actions.RunnerScaleSet{}, - MessageQueueUrl: "https://example.com", - MessageQueueAccessToken: "1234567890", - Statistics: nil, - } - client.On("RefreshMessageSession", ctx, mock.Anything, mock.Anything).Return(session, nil).Once() - - client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything, 10).Return(nil, &actions.MessageQueueTokenExpiredError{}).Twice() - - config.Client = client - - l, err := New(config) - require.Nil(t, err) - - l.session = &actions.RunnerScaleSetSession{ - SessionId: &uuid, - RunnerScaleSet: &actions.RunnerScaleSet{}, - } - - got, err := l.getMessage(ctx) - assert.NotNil(t, err) - assert.Nil(t, got) - }) -} - -func TestListener_refreshSession(t *testing.T) { - t.Parallel() - - t.Run("SuccessfullyRefreshes", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - config := Config{ - ScaleSetID: 1, - Metrics: metrics.Discard, - } - - client := 
listenermocks.NewClient(t) - - newUUID := uuid.New() - session := &actions.RunnerScaleSetSession{ - SessionId: &newUUID, - OwnerName: "example", - RunnerScaleSet: &actions.RunnerScaleSet{}, - MessageQueueUrl: "https://example.com", - MessageQueueAccessToken: "1234567890", - Statistics: nil, - } - client.On("RefreshMessageSession", ctx, mock.Anything, mock.Anything).Return(session, nil).Once() - - config.Client = client - - l, err := New(config) - require.Nil(t, err) - - oldUUID := uuid.New() - l.session = &actions.RunnerScaleSetSession{ - SessionId: &oldUUID, - RunnerScaleSet: &actions.RunnerScaleSet{}, - } - - err = l.refreshSession(ctx) - assert.Nil(t, err) - assert.Equal(t, session, l.session) - }) - - t.Run("FailsToRefresh", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - config := Config{ - ScaleSetID: 1, - Metrics: metrics.Discard, - } - - client := listenermocks.NewClient(t) - - client.On("RefreshMessageSession", ctx, mock.Anything, mock.Anything).Return(nil, errors.New("error")).Once() - - config.Client = client - - l, err := New(config) - require.Nil(t, err) - - oldUUID := uuid.New() - oldSession := &actions.RunnerScaleSetSession{ - SessionId: &oldUUID, - RunnerScaleSet: &actions.RunnerScaleSet{}, - } - l.session = oldSession - - err = l.refreshSession(ctx) - assert.NotNil(t, err) - assert.Equal(t, oldSession, l.session) - }) -} - -func TestListener_deleteLastMessage(t *testing.T) { - t.Parallel() - - t.Run("SuccessfullyDeletes", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - config := Config{ - ScaleSetID: 1, - Metrics: metrics.Discard, - } - - client := listenermocks.NewClient(t) - - client.On("DeleteMessage", ctx, mock.Anything, mock.Anything, mock.MatchedBy(func(lastMessageID any) bool { - return lastMessageID.(int64) == int64(5) - })).Return(nil).Once() - - config.Client = client - - l, err := New(config) - require.Nil(t, err) - - l.session = &actions.RunnerScaleSetSession{} - l.lastMessageID = 5 - - err = l.deleteLastMessage(ctx) - assert.Nil(t, err) - }) - - t.Run("FailsToDelete", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - config := Config{ - ScaleSetID: 1, - Metrics: metrics.Discard, - } - - client := listenermocks.NewClient(t) - - client.On("DeleteMessage", ctx, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("error")).Once() - - config.Client = client - - l, err := New(config) - require.Nil(t, err) - - l.session = &actions.RunnerScaleSetSession{} - l.lastMessageID = 5 - - err = l.deleteLastMessage(ctx) - assert.NotNil(t, err) - }) - - t.Run("RefreshAndSucceeds", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - config := Config{ - ScaleSetID: 1, - Metrics: metrics.Discard, - } - - client := listenermocks.NewClient(t) - - newUUID := uuid.New() - session := &actions.RunnerScaleSetSession{ - SessionId: &newUUID, - OwnerName: "example", - RunnerScaleSet: &actions.RunnerScaleSet{}, - MessageQueueUrl: "https://example.com", - MessageQueueAccessToken: "1234567890", - Statistics: nil, - } - client.On("RefreshMessageSession", ctx, mock.Anything, mock.Anything).Return(session, nil).Once() - - client.On("DeleteMessage", ctx, mock.Anything, mock.Anything, mock.Anything).Return(&actions.MessageQueueTokenExpiredError{}).Once() - - client.On("DeleteMessage", ctx, mock.Anything, mock.Anything, mock.MatchedBy(func(lastMessageID any) bool { - return lastMessageID.(int64) == int64(5) - })).Return(nil).Once() - config.Client = client - - l, err := New(config) - require.Nil(t, err) - - 
oldUUID := uuid.New() - l.session = &actions.RunnerScaleSetSession{ - SessionId: &oldUUID, - RunnerScaleSet: &actions.RunnerScaleSet{}, - } - l.lastMessageID = 5 - - config.Client = client - - err = l.deleteLastMessage(ctx) - assert.NoError(t, err) - }) - - t.Run("RefreshAndFails", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - config := Config{ - ScaleSetID: 1, - Metrics: metrics.Discard, - } - - client := listenermocks.NewClient(t) - - newUUID := uuid.New() - session := &actions.RunnerScaleSetSession{ - SessionId: &newUUID, - OwnerName: "example", - RunnerScaleSet: &actions.RunnerScaleSet{}, - MessageQueueUrl: "https://example.com", - MessageQueueAccessToken: "1234567890", - Statistics: nil, - } - client.On("RefreshMessageSession", ctx, mock.Anything, mock.Anything).Return(session, nil).Once() - - client.On("DeleteMessage", ctx, mock.Anything, mock.Anything, mock.Anything).Return(&actions.MessageQueueTokenExpiredError{}).Twice() - - config.Client = client - - l, err := New(config) - require.Nil(t, err) - - oldUUID := uuid.New() - l.session = &actions.RunnerScaleSetSession{ - SessionId: &oldUUID, - RunnerScaleSet: &actions.RunnerScaleSet{}, - } - l.lastMessageID = 5 - - config.Client = client - - err = l.deleteLastMessage(ctx) - assert.Error(t, err) - }) -} - -func TestListener_Listen(t *testing.T) { - t.Parallel() - - t.Run("CreateSessionFails", func(t *testing.T) { - t.Parallel() - ctx := context.Background() - config := Config{ - ScaleSetID: 1, - Metrics: metrics.Discard, - } - - client := listenermocks.NewClient(t) - client.On("CreateMessageSession", ctx, mock.Anything, mock.Anything).Return(nil, assert.AnError).Once() - config.Client = client - - l, err := New(config) - require.Nil(t, err) - - err = l.Listen(ctx, nil) - assert.NotNil(t, err) - }) - - t.Run("CallHandleRegardlessOfInitialMessage", func(t *testing.T) { - t.Parallel() - ctx, cancel := context.WithCancel(context.Background()) - - config := Config{ - ScaleSetID: 1, - Metrics: metrics.Discard, - } - - client := listenermocks.NewClient(t) - - uuid := uuid.New() - session := &actions.RunnerScaleSetSession{ - SessionId: &uuid, - OwnerName: "example", - RunnerScaleSet: &actions.RunnerScaleSet{}, - MessageQueueUrl: "https://example.com", - MessageQueueAccessToken: "1234567890", - Statistics: &actions.RunnerScaleSetStatistic{}, - } - client.On("CreateMessageSession", ctx, mock.Anything, mock.Anything).Return(session, nil).Once() - client.On("DeleteMessageSession", mock.Anything, session.RunnerScaleSet.Id, session.SessionId).Return(nil).Once() - - config.Client = client - - l, err := New(config) - require.Nil(t, err) - - var called bool - handler := listenermocks.NewHandler(t) - handler.On("HandleDesiredRunnerCount", mock.Anything, mock.Anything, 0). - Return(0, nil). - Run( - func(mock.Arguments) { - called = true - cancel() - }, - ). 
- Once() - - err = l.Listen(ctx, handler) - assert.True(t, errors.Is(err, context.Canceled)) - assert.True(t, called) - }) - - t.Run("CancelContextAfterGetMessage", func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithCancel(context.Background()) - - config := Config{ - ScaleSetID: 1, - Metrics: metrics.Discard, - MaxRunners: 10, - } - - client := listenermocks.NewClient(t) - uuid := uuid.New() - session := &actions.RunnerScaleSetSession{ - SessionId: &uuid, - OwnerName: "example", - RunnerScaleSet: &actions.RunnerScaleSet{}, - MessageQueueUrl: "https://example.com", - MessageQueueAccessToken: "1234567890", - Statistics: &actions.RunnerScaleSetStatistic{}, - } - client.On("CreateMessageSession", ctx, mock.Anything, mock.Anything).Return(session, nil).Once() - client.On("DeleteMessageSession", mock.Anything, session.RunnerScaleSet.Id, session.SessionId).Return(nil).Once() - - msg := &actions.RunnerScaleSetMessage{ - MessageId: 1, - MessageType: "RunnerScaleSetJobMessages", - Statistics: &actions.RunnerScaleSetStatistic{}, - } - client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything, 10). - Return(msg, nil). - Run( - func(mock.Arguments) { - cancel() - }, - ). - Once() - - // Ensure delete message is called without cancel - client.On("DeleteMessage", context.WithoutCancel(ctx), mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() - - config.Client = client - - handler := listenermocks.NewHandler(t) - handler.On("HandleDesiredRunnerCount", mock.Anything, mock.Anything, 0). - Return(0, nil). - Once() - - handler.On("HandleDesiredRunnerCount", mock.Anything, mock.Anything, 0). - Return(0, nil). - Once() - - l, err := New(config) - require.Nil(t, err) - - err = l.Listen(ctx, handler) - assert.ErrorIs(t, context.Canceled, err) - }) -} - -func TestListener_acquireAvailableJobs(t *testing.T) { - t.Parallel() - - t.Run("FailingToAcquireJobs", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - config := Config{ - ScaleSetID: 1, - Metrics: metrics.Discard, - } - - client := listenermocks.NewClient(t) - - client.On("AcquireJobs", ctx, mock.Anything, mock.Anything, mock.Anything).Return(nil, assert.AnError).Once() - - config.Client = client - - l, err := New(config) - require.Nil(t, err) - - uuid := uuid.New() - l.session = &actions.RunnerScaleSetSession{ - SessionId: &uuid, - OwnerName: "example", - RunnerScaleSet: &actions.RunnerScaleSet{}, - MessageQueueUrl: "https://example.com", - MessageQueueAccessToken: "1234567890", - Statistics: &actions.RunnerScaleSetStatistic{}, - } - - availableJobs := []*actions.JobAvailable{ - { - JobMessageBase: actions.JobMessageBase{ - RunnerRequestID: 1, - }, - }, - { - JobMessageBase: actions.JobMessageBase{ - RunnerRequestID: 2, - }, - }, - { - JobMessageBase: actions.JobMessageBase{ - RunnerRequestID: 3, - }, - }, - } - _, err = l.acquireAvailableJobs(ctx, availableJobs) - assert.Error(t, err) - }) - - t.Run("SuccessfullyAcquiresJobsOnFirstRun", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - config := Config{ - ScaleSetID: 1, - Metrics: metrics.Discard, - } - - client := listenermocks.NewClient(t) - - jobIDs := []int64{1, 2, 3} - - client.On("AcquireJobs", ctx, mock.Anything, mock.Anything, mock.Anything).Return(jobIDs, nil).Once() - - config.Client = client - - l, err := New(config) - require.Nil(t, err) - - uuid := uuid.New() - l.session = &actions.RunnerScaleSetSession{ - SessionId: &uuid, - OwnerName: "example", - RunnerScaleSet: &actions.RunnerScaleSet{}, - MessageQueueUrl: 
"https://example.com", - MessageQueueAccessToken: "1234567890", - Statistics: &actions.RunnerScaleSetStatistic{}, - } - - availableJobs := []*actions.JobAvailable{ - { - JobMessageBase: actions.JobMessageBase{ - RunnerRequestID: 1, - }, - }, - { - JobMessageBase: actions.JobMessageBase{ - RunnerRequestID: 2, - }, - }, - { - JobMessageBase: actions.JobMessageBase{ - RunnerRequestID: 3, - }, - }, - } - acquiredJobIDs, err := l.acquireAvailableJobs(ctx, availableJobs) - assert.NoError(t, err) - assert.Equal(t, []int64{1, 2, 3}, acquiredJobIDs) - }) - - t.Run("RefreshAndSucceeds", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - config := Config{ - ScaleSetID: 1, - Metrics: metrics.Discard, - } - - client := listenermocks.NewClient(t) - - uuid := uuid.New() - session := &actions.RunnerScaleSetSession{ - SessionId: &uuid, - OwnerName: "example", - RunnerScaleSet: &actions.RunnerScaleSet{}, - MessageQueueUrl: "https://example.com", - MessageQueueAccessToken: "1234567890", - Statistics: nil, - } - client.On("RefreshMessageSession", ctx, mock.Anything, mock.Anything).Return(session, nil).Once() - - // Second call to AcquireJobs will succeed - want := []int64{1, 2, 3} - availableJobs := []*actions.JobAvailable{ - { - JobMessageBase: actions.JobMessageBase{ - RunnerRequestID: 1, - }, - }, - { - JobMessageBase: actions.JobMessageBase{ - RunnerRequestID: 2, - }, - }, - { - JobMessageBase: actions.JobMessageBase{ - RunnerRequestID: 3, - }, - }, - } - - // First call to AcquireJobs will fail with a token expired error - client.On("AcquireJobs", ctx, mock.Anything, mock.Anything, mock.Anything). - Run(func(args mock.Arguments) { - ids := args.Get(3).([]int64) - assert.Equal(t, want, ids) - }). - Return(nil, &actions.MessageQueueTokenExpiredError{}). - Once() - - // Second call should succeed - client.On("AcquireJobs", ctx, mock.Anything, mock.Anything, mock.Anything). - Run(func(args mock.Arguments) { - ids := args.Get(3).([]int64) - assert.Equal(t, want, ids) - }). - Return(want, nil). 
- Once() - - config.Client = client - - l, err := New(config) - require.Nil(t, err) - - l.session = &actions.RunnerScaleSetSession{ - SessionId: &uuid, - RunnerScaleSet: &actions.RunnerScaleSet{}, - } - - got, err := l.acquireAvailableJobs(ctx, availableJobs) - assert.Nil(t, err) - assert.Equal(t, want, got) - }) - - t.Run("RefreshAndFails", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - config := Config{ - ScaleSetID: 1, - Metrics: metrics.Discard, - } - - client := listenermocks.NewClient(t) - - uuid := uuid.New() - session := &actions.RunnerScaleSetSession{ - SessionId: &uuid, - OwnerName: "example", - RunnerScaleSet: &actions.RunnerScaleSet{}, - MessageQueueUrl: "https://example.com", - MessageQueueAccessToken: "1234567890", - Statistics: nil, - } - client.On("RefreshMessageSession", ctx, mock.Anything, mock.Anything).Return(session, nil).Once() - - client.On("AcquireJobs", ctx, mock.Anything, mock.Anything, mock.Anything).Return(nil, &actions.MessageQueueTokenExpiredError{}).Twice() - - config.Client = client - - l, err := New(config) - require.Nil(t, err) - - l.session = &actions.RunnerScaleSetSession{ - SessionId: &uuid, - RunnerScaleSet: &actions.RunnerScaleSet{}, - } - - availableJobs := []*actions.JobAvailable{ - { - JobMessageBase: actions.JobMessageBase{ - RunnerRequestID: 1, - }, - }, - { - JobMessageBase: actions.JobMessageBase{ - RunnerRequestID: 2, - }, - }, - { - JobMessageBase: actions.JobMessageBase{ - RunnerRequestID: 3, - }, - }, - } - - got, err := l.acquireAvailableJobs(ctx, availableJobs) - assert.NotNil(t, err) - assert.Nil(t, got) - }) -} - -func TestListener_parseMessage(t *testing.T) { - t.Run("FailOnEmptyStatistics", func(t *testing.T) { - msg := &actions.RunnerScaleSetMessage{ - MessageId: 1, - MessageType: "RunnerScaleSetJobMessages", - Statistics: nil, - } - - l := &Listener{} - parsedMsg, err := l.parseMessage(context.Background(), msg) - assert.Error(t, err) - assert.Nil(t, parsedMsg) - }) - - t.Run("FailOnIncorrectMessageType", func(t *testing.T) { - msg := &actions.RunnerScaleSetMessage{ - MessageId: 1, - MessageType: "RunnerMessages", // arbitrary message type - Statistics: &actions.RunnerScaleSetStatistic{}, - } - - l := &Listener{} - parsedMsg, err := l.parseMessage(context.Background(), msg) - assert.Error(t, err) - assert.Nil(t, parsedMsg) - }) - - t.Run("ParseAll", func(t *testing.T) { - msg := &actions.RunnerScaleSetMessage{ - MessageId: 1, - MessageType: "RunnerScaleSetJobMessages", - Body: "", - Statistics: &actions.RunnerScaleSetStatistic{ - TotalAvailableJobs: 1, - TotalAcquiredJobs: 2, - TotalAssignedJobs: 3, - TotalRunningJobs: 4, - TotalRegisteredRunners: 5, - TotalBusyRunners: 6, - TotalIdleRunners: 7, - }, - } - - var batchedMessages []any - jobsAvailable := []*actions.JobAvailable{ - { - AcquireJobUrl: "https://github.com/example", - JobMessageBase: actions.JobMessageBase{ - JobMessageType: actions.JobMessageType{ - MessageType: messageTypeJobAvailable, - }, - RunnerRequestID: 1, - }, - }, - { - AcquireJobUrl: "https://github.com/example", - JobMessageBase: actions.JobMessageBase{ - JobMessageType: actions.JobMessageType{ - MessageType: messageTypeJobAvailable, - }, - RunnerRequestID: 2, - }, - }, - } - for _, msg := range jobsAvailable { - batchedMessages = append(batchedMessages, msg) - } - - jobsAssigned := []*actions.JobAssigned{ - { - JobMessageBase: actions.JobMessageBase{ - JobMessageType: actions.JobMessageType{ - MessageType: messageTypeJobAssigned, - }, - RunnerRequestID: 3, - }, - }, - { - 
JobMessageBase: actions.JobMessageBase{ - JobMessageType: actions.JobMessageType{ - MessageType: messageTypeJobAssigned, - }, - RunnerRequestID: 4, - }, - }, - } - for _, msg := range jobsAssigned { - batchedMessages = append(batchedMessages, msg) - } - - jobsStarted := []*actions.JobStarted{ - { - JobMessageBase: actions.JobMessageBase{ - JobMessageType: actions.JobMessageType{ - MessageType: messageTypeJobStarted, - }, - RunnerRequestID: 5, - }, - RunnerID: 2, - RunnerName: "runner2", - }, - } - for _, msg := range jobsStarted { - batchedMessages = append(batchedMessages, msg) - } - - jobsCompleted := []*actions.JobCompleted{ - { - JobMessageBase: actions.JobMessageBase{ - JobMessageType: actions.JobMessageType{ - MessageType: messageTypeJobCompleted, - }, - RunnerRequestID: 6, - }, - Result: "success", - RunnerId: 1, - RunnerName: "runner1", - }, - } - for _, msg := range jobsCompleted { - batchedMessages = append(batchedMessages, msg) - } - - b, err := json.Marshal(batchedMessages) - require.NoError(t, err) - - msg.Body = string(b) - - l := &Listener{} - parsedMsg, err := l.parseMessage(context.Background(), msg) - require.NoError(t, err) - - assert.Equal(t, msg.Statistics, parsedMsg.statistics) - assert.Equal(t, jobsAvailable, parsedMsg.jobsAvailable) - assert.Equal(t, jobsStarted, parsedMsg.jobsStarted) - assert.Equal(t, jobsCompleted, parsedMsg.jobsCompleted) - }) -} diff --git a/cmd/ghalistener/listener/metrics_test.go b/cmd/ghalistener/listener/metrics_test.go deleted file mode 100644 index 975619b9..00000000 --- a/cmd/ghalistener/listener/metrics_test.go +++ /dev/null @@ -1,205 +0,0 @@ -package listener - -import ( - "context" - "encoding/json" - "testing" - - listenermocks "github.com/actions/actions-runner-controller/cmd/ghalistener/listener/mocks" - metricsmocks "github.com/actions/actions-runner-controller/cmd/ghalistener/metrics/mocks" - "github.com/actions/actions-runner-controller/github/actions" - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" -) - -func TestInitialMetrics(t *testing.T) { - t.Parallel() - - t.Run("SetStaticMetrics", func(t *testing.T) { - t.Parallel() - - metrics := metricsmocks.NewPublisher(t) - - minRunners := 5 - maxRunners := 10 - metrics.On("PublishStatic", minRunners, maxRunners).Once() - - config := Config{ - Client: listenermocks.NewClient(t), - ScaleSetID: 1, - Metrics: metrics, - MinRunners: minRunners, - MaxRunners: maxRunners, - } - l, err := New(config) - - assert.Nil(t, err) - assert.NotNil(t, l) - }) - - t.Run("InitialMessageStatistics", func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithCancel(context.Background()) - - sessionStatistics := &actions.RunnerScaleSetStatistic{ - TotalAvailableJobs: 1, - TotalAcquiredJobs: 2, - TotalAssignedJobs: 3, - TotalRunningJobs: 4, - TotalRegisteredRunners: 5, - TotalBusyRunners: 6, - TotalIdleRunners: 7, - } - - uuid := uuid.New() - session := &actions.RunnerScaleSetSession{ - SessionId: &uuid, - OwnerName: "example", - RunnerScaleSet: &actions.RunnerScaleSet{}, - MessageQueueUrl: "https://example.com", - MessageQueueAccessToken: "1234567890", - Statistics: sessionStatistics, - } - - metrics := metricsmocks.NewPublisher(t) - metrics.On("PublishStatic", mock.Anything, mock.Anything).Once() - metrics.On("PublishStatistics", sessionStatistics).Once() - metrics.On("PublishDesiredRunners", sessionStatistics.TotalAssignedJobs). 
- Run( - func(mock.Arguments) { - cancel() - }, - ).Once() - - config := Config{ - Client: listenermocks.NewClient(t), - ScaleSetID: 1, - Metrics: metrics, - } - - client := listenermocks.NewClient(t) - client.On("CreateMessageSession", mock.Anything, mock.Anything, mock.Anything).Return(session, nil).Once() - client.On("DeleteMessageSession", mock.Anything, session.RunnerScaleSet.Id, session.SessionId).Return(nil).Once() - config.Client = client - - handler := listenermocks.NewHandler(t) - handler.On("HandleDesiredRunnerCount", mock.Anything, sessionStatistics.TotalAssignedJobs, 0). - Return(sessionStatistics.TotalAssignedJobs, nil). - Once() - - l, err := New(config) - assert.Nil(t, err) - assert.NotNil(t, l) - - assert.ErrorIs(t, context.Canceled, l.Listen(ctx, handler)) - }) -} - -func TestHandleMessageMetrics(t *testing.T) { - t.Parallel() - - msg := &actions.RunnerScaleSetMessage{ - MessageId: 1, - MessageType: "RunnerScaleSetJobMessages", - Body: "", - Statistics: &actions.RunnerScaleSetStatistic{ - TotalAvailableJobs: 1, - TotalAcquiredJobs: 2, - TotalAssignedJobs: 3, - TotalRunningJobs: 4, - TotalRegisteredRunners: 5, - TotalBusyRunners: 6, - TotalIdleRunners: 7, - }, - } - - var batchedMessages []any - jobsStarted := []*actions.JobStarted{ - { - JobMessageBase: actions.JobMessageBase{ - JobMessageType: actions.JobMessageType{ - MessageType: messageTypeJobStarted, - }, - RunnerRequestID: 8, - }, - RunnerID: 3, - RunnerName: "runner3", - }, - } - for _, msg := range jobsStarted { - batchedMessages = append(batchedMessages, msg) - } - - jobsCompleted := []*actions.JobCompleted{ - { - JobMessageBase: actions.JobMessageBase{ - JobMessageType: actions.JobMessageType{ - MessageType: messageTypeJobCompleted, - }, - RunnerRequestID: 6, - }, - Result: "success", - RunnerId: 1, - RunnerName: "runner1", - }, - { - JobMessageBase: actions.JobMessageBase{ - JobMessageType: actions.JobMessageType{ - MessageType: messageTypeJobCompleted, - }, - RunnerRequestID: 7, - }, - Result: "success", - RunnerId: 2, - RunnerName: "runner2", - }, - } - for _, msg := range jobsCompleted { - batchedMessages = append(batchedMessages, msg) - } - - b, err := json.Marshal(batchedMessages) - require.NoError(t, err) - - msg.Body = string(b) - - desiredResult := 4 - - metrics := metricsmocks.NewPublisher(t) - metrics.On("PublishStatic", 0, 0).Once() - metrics.On("PublishStatistics", msg.Statistics).Once() - metrics.On("PublishJobCompleted", jobsCompleted[0]).Once() - metrics.On("PublishJobCompleted", jobsCompleted[1]).Once() - metrics.On("PublishJobStarted", jobsStarted[0]).Once() - metrics.On("PublishDesiredRunners", desiredResult).Once() - - handler := listenermocks.NewHandler(t) - handler.On("HandleJobStarted", mock.Anything, jobsStarted[0]).Return(nil).Once() - handler.On("HandleDesiredRunnerCount", mock.Anything, mock.Anything, 2).Return(desiredResult, nil).Once() - - client := listenermocks.NewClient(t) - client.On("DeleteMessage", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() - - config := Config{ - Client: listenermocks.NewClient(t), - ScaleSetID: 1, - Metrics: metrics, - } - - l, err := New(config) - require.NoError(t, err) - l.client = client - l.session = &actions.RunnerScaleSetSession{ - OwnerName: "", - RunnerScaleSet: &actions.RunnerScaleSet{}, - MessageQueueUrl: "", - MessageQueueAccessToken: "", - Statistics: &actions.RunnerScaleSetStatistic{}, - } - - err = l.handleMessage(context.Background(), handler, msg) - require.NoError(t, err) -} diff --git 
diff --git a/cmd/ghalistener/listener/mocks/client.go b/cmd/ghalistener/listener/mocks/client.go
deleted file mode 100644
index a36c9344..00000000
--- a/cmd/ghalistener/listener/mocks/client.go
+++ /dev/null
@@ -1,190 +0,0 @@
-// Code generated by mockery v2.36.1. DO NOT EDIT.
-
-package mocks
-
-import (
-	context "context"
-
-	actions "github.com/actions/actions-runner-controller/github/actions"
-
-	mock "github.com/stretchr/testify/mock"
-
-	uuid "github.com/google/uuid"
-)
-
-// Client is an autogenerated mock type for the Client type
-type Client struct {
-	mock.Mock
-}
-
-// AcquireJobs provides a mock function with given fields: ctx, runnerScaleSetId, messageQueueAccessToken, requestIds
-func (_m *Client) AcquireJobs(ctx context.Context, runnerScaleSetId int, messageQueueAccessToken string, requestIds []int64) ([]int64, error) {
-	ret := _m.Called(ctx, runnerScaleSetId, messageQueueAccessToken, requestIds)
-
-	var r0 []int64
-	var r1 error
-	if rf, ok := ret.Get(0).(func(context.Context, int, string, []int64) ([]int64, error)); ok {
-		return rf(ctx, runnerScaleSetId, messageQueueAccessToken, requestIds)
-	}
-	if rf, ok := ret.Get(0).(func(context.Context, int, string, []int64) []int64); ok {
-		r0 = rf(ctx, runnerScaleSetId, messageQueueAccessToken, requestIds)
-	} else {
-		if ret.Get(0) != nil {
-			r0 = ret.Get(0).([]int64)
-		}
-	}
-
-	if rf, ok := ret.Get(1).(func(context.Context, int, string, []int64) error); ok {
-		r1 = rf(ctx, runnerScaleSetId, messageQueueAccessToken, requestIds)
-	} else {
-		r1 = ret.Error(1)
-	}
-
-	return r0, r1
-}
-
-// CreateMessageSession provides a mock function with given fields: ctx, runnerScaleSetId, owner
-func (_m *Client) CreateMessageSession(ctx context.Context, runnerScaleSetId int, owner string) (*actions.RunnerScaleSetSession, error) {
-	ret := _m.Called(ctx, runnerScaleSetId, owner)
-
-	var r0 *actions.RunnerScaleSetSession
-	var r1 error
-	if rf, ok := ret.Get(0).(func(context.Context, int, string) (*actions.RunnerScaleSetSession, error)); ok {
-		return rf(ctx, runnerScaleSetId, owner)
-	}
-	if rf, ok := ret.Get(0).(func(context.Context, int, string) *actions.RunnerScaleSetSession); ok {
-		r0 = rf(ctx, runnerScaleSetId, owner)
-	} else {
-		if ret.Get(0) != nil {
-			r0 = ret.Get(0).(*actions.RunnerScaleSetSession)
-		}
-	}
-
-	if rf, ok := ret.Get(1).(func(context.Context, int, string) error); ok {
-		r1 = rf(ctx, runnerScaleSetId, owner)
-	} else {
-		r1 = ret.Error(1)
-	}
-
-	return r0, r1
-}
-
-// DeleteMessage provides a mock function with given fields: ctx, messageQueueUrl, messageQueueAccessToken, messageId
-func (_m *Client) DeleteMessage(ctx context.Context, messageQueueUrl string, messageQueueAccessToken string, messageId int64) error {
-	ret := _m.Called(ctx, messageQueueUrl, messageQueueAccessToken, messageId)
-
-	var r0 error
-	if rf, ok := ret.Get(0).(func(context.Context, string, string, int64) error); ok {
-		r0 = rf(ctx, messageQueueUrl, messageQueueAccessToken, messageId)
-	} else {
-		r0 = ret.Error(0)
-	}
-
-	return r0
-}
-
-// DeleteMessageSession provides a mock function with given fields: ctx, runnerScaleSetId, sessionId
-func (_m *Client) DeleteMessageSession(ctx context.Context, runnerScaleSetId int, sessionId *uuid.UUID) error {
-	ret := _m.Called(ctx, runnerScaleSetId, sessionId)
-
-	var r0 error
-	if rf, ok := ret.Get(0).(func(context.Context, int, *uuid.UUID) error); ok {
-		r0 = rf(ctx, runnerScaleSetId, sessionId)
-	} else {
-		r0 = ret.Error(0)
-	}
-
-	return r0
-}
-
-// GetAcquirableJobs provides a mock function with given fields: ctx, runnerScaleSetId
-func (_m *Client) GetAcquirableJobs(ctx context.Context, runnerScaleSetId int) (*actions.AcquirableJobList, error) {
-	ret := _m.Called(ctx, runnerScaleSetId)
-
-	var r0 *actions.AcquirableJobList
-	var r1 error
-	if rf, ok := ret.Get(0).(func(context.Context, int) (*actions.AcquirableJobList, error)); ok {
-		return rf(ctx, runnerScaleSetId)
-	}
-	if rf, ok := ret.Get(0).(func(context.Context, int) *actions.AcquirableJobList); ok {
-		r0 = rf(ctx, runnerScaleSetId)
-	} else {
-		if ret.Get(0) != nil {
-			r0 = ret.Get(0).(*actions.AcquirableJobList)
-		}
-	}
-
-	if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
-		r1 = rf(ctx, runnerScaleSetId)
-	} else {
-		r1 = ret.Error(1)
-	}
-
-	return r0, r1
-}
-
-// GetMessage provides a mock function with given fields: ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId, maxCapacity
-func (_m *Client) GetMessage(ctx context.Context, messageQueueUrl string, messageQueueAccessToken string, lastMessageId int64, maxCapacity int) (*actions.RunnerScaleSetMessage, error) {
-	ret := _m.Called(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId, maxCapacity)
-
-	var r0 *actions.RunnerScaleSetMessage
-	var r1 error
-	if rf, ok := ret.Get(0).(func(context.Context, string, string, int64, int) (*actions.RunnerScaleSetMessage, error)); ok {
-		return rf(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId, maxCapacity)
-	}
-	if rf, ok := ret.Get(0).(func(context.Context, string, string, int64, int) *actions.RunnerScaleSetMessage); ok {
-		r0 = rf(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId, maxCapacity)
-	} else {
-		if ret.Get(0) != nil {
-			r0 = ret.Get(0).(*actions.RunnerScaleSetMessage)
-		}
-	}
-
-	if rf, ok := ret.Get(1).(func(context.Context, string, string, int64, int) error); ok {
-		r1 = rf(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId, maxCapacity)
-	} else {
-		r1 = ret.Error(1)
-	}
-
-	return r0, r1
-}
-
-// RefreshMessageSession provides a mock function with given fields: ctx, runnerScaleSetId, sessionId
-func (_m *Client) RefreshMessageSession(ctx context.Context, runnerScaleSetId int, sessionId *uuid.UUID) (*actions.RunnerScaleSetSession, error) {
-	ret := _m.Called(ctx, runnerScaleSetId, sessionId)
-
-	var r0 *actions.RunnerScaleSetSession
-	var r1 error
-	if rf, ok := ret.Get(0).(func(context.Context, int, *uuid.UUID) (*actions.RunnerScaleSetSession, error)); ok {
-		return rf(ctx, runnerScaleSetId, sessionId)
-	}
-	if rf, ok := ret.Get(0).(func(context.Context, int, *uuid.UUID) *actions.RunnerScaleSetSession); ok {
-		r0 = rf(ctx, runnerScaleSetId, sessionId)
-	} else {
-		if ret.Get(0) != nil {
-			r0 = ret.Get(0).(*actions.RunnerScaleSetSession)
-		}
-	}
-
-	if rf, ok := ret.Get(1).(func(context.Context, int, *uuid.UUID) error); ok {
-		r1 = rf(ctx, runnerScaleSetId, sessionId)
-	} else {
-		r1 = ret.Error(1)
-	}
-
-	return r0, r1
-}
-
-// NewClient creates a new instance of Client. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
-// The first argument is typically a *testing.T value.
-func NewClient(t interface {
-	mock.TestingT
-	Cleanup(func())
-}) *Client {
-	mock := &Client{}
-	mock.Mock.Test(t)
-
-	t.Cleanup(func() { mock.AssertExpectations(t) })
-
-	return mock
-}
diff --git a/cmd/ghalistener/listener/mocks/handler.go b/cmd/ghalistener/listener/mocks/handler.go
deleted file mode 100644
index b910d79f..00000000
--- a/cmd/ghalistener/listener/mocks/handler.go
+++ /dev/null
@@ -1,68 +0,0 @@
-// Code generated by mockery v2.36.1. DO NOT EDIT.
-
-package mocks
-
-import (
-	context "context"
-
-	actions "github.com/actions/actions-runner-controller/github/actions"
-
-	mock "github.com/stretchr/testify/mock"
-)
-
-// Handler is an autogenerated mock type for the Handler type
-type Handler struct {
-	mock.Mock
-}
-
-// HandleDesiredRunnerCount provides a mock function with given fields: ctx, count, jobsCompleted
-func (_m *Handler) HandleDesiredRunnerCount(ctx context.Context, count int, jobsCompleted int) (int, error) {
-	ret := _m.Called(ctx, count, jobsCompleted)
-
-	var r0 int
-	var r1 error
-	if rf, ok := ret.Get(0).(func(context.Context, int, int) (int, error)); ok {
-		return rf(ctx, count, jobsCompleted)
-	}
-	if rf, ok := ret.Get(0).(func(context.Context, int, int) int); ok {
-		r0 = rf(ctx, count, jobsCompleted)
-	} else {
-		r0 = ret.Get(0).(int)
-	}
-
-	if rf, ok := ret.Get(1).(func(context.Context, int, int) error); ok {
-		r1 = rf(ctx, count, jobsCompleted)
-	} else {
-		r1 = ret.Error(1)
-	}
-
-	return r0, r1
-}
-
-// HandleJobStarted provides a mock function with given fields: ctx, jobInfo
-func (_m *Handler) HandleJobStarted(ctx context.Context, jobInfo *actions.JobStarted) error {
-	ret := _m.Called(ctx, jobInfo)
-
-	var r0 error
-	if rf, ok := ret.Get(0).(func(context.Context, *actions.JobStarted) error); ok {
-		r0 = rf(ctx, jobInfo)
-	} else {
-		r0 = ret.Error(0)
-	}
-
-	return r0
-}
-
-// NewHandler creates a new instance of Handler. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
-// The first argument is typically a *testing.T value.
-func NewHandler(t interface {
-	mock.TestingT
-	Cleanup(func())
-}) *Handler {
-	mock := &Handler{}
-	mock.Mock.Test(t)
-
-	t.Cleanup(func() { mock.AssertExpectations(t) })
-
-	return mock
-}
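With the listener now provided by github.com/actions/scaleset, these generated mocks have no local interface left to stub. Tests that still need a double can hand-roll a small fake against the new listener.Scaler contract instead. The sketch below is illustrative only: the method set is inferred from the Worker implementation later in this patch, not from the scaleset module's own definition.

package fakes

import (
	"context"

	"github.com/actions/scaleset"
)

// fakeScaler records every call so tests can assert on the sequence of
// events the listener delivered. It is a hypothetical replacement for the
// deleted mockery mocks.
type fakeScaler struct {
	started   []*scaleset.JobStarted
	completed []*scaleset.JobCompleted
	counts    []int
}

func (f *fakeScaler) HandleJobStarted(_ context.Context, msg *scaleset.JobStarted) error {
	f.started = append(f.started, msg)
	return nil
}

func (f *fakeScaler) HandleJobCompleted(_ context.Context, msg *scaleset.JobCompleted) error {
	f.completed = append(f.completed, msg)
	return nil
}

func (f *fakeScaler) HandleDesiredRunnerCount(_ context.Context, count int) (int, error) {
	f.counts = append(f.counts, count)
	return count, nil
}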
diff --git a/cmd/ghalistener/metrics/metrics.go b/cmd/ghalistener/metrics/metrics.go
index 7f51e3e9..c894f050 100644
--- a/cmd/ghalistener/metrics/metrics.go
+++ b/cmd/ghalistener/metrics/metrics.go
@@ -9,8 +9,7 @@ import (
 	"time"
 
 	"github.com/actions/actions-runner-controller/apis/actions.github.com/v1alpha1"
-	"github.com/actions/actions-runner-controller/github/actions"
-	"github.com/go-logr/logr"
+	"github.com/actions/scaleset"
 	"github.com/prometheus/client_golang/prometheus"
 	"github.com/prometheus/client_golang/prometheus/promhttp"
 )
@@ -77,7 +76,7 @@ var metricsHelp = metricsHelpRegistry{
 	},
 }
 
-func (e *exporter) jobLabels(jobBase *actions.JobMessageBase) prometheus.Labels {
+func (e *exporter) jobLabels(jobBase *scaleset.JobMessageBase) prometheus.Labels {
 	workflowRefInfo := ParseWorkflowRef(jobBase.JobWorkflowRef)
 	return prometheus.Labels{
 		labelKeyEnterprise: e.scaleSetLabels[labelKeyEnterprise],
@@ -91,22 +90,22 @@ func (e *exporter) jobLabels(jobBase *actions.JobMessageBase) prometheus.Labels
 	}
 }
 
-func (e *exporter) completedJobLabels(msg *actions.JobCompleted) prometheus.Labels {
+func (e *exporter) completedJobLabels(msg *scaleset.JobCompleted) prometheus.Labels {
 	l := e.jobLabels(&msg.JobMessageBase)
 	l[labelKeyJobResult] = msg.Result
 	return l
 }
 
-func (e *exporter) startedJobLabels(msg *actions.JobStarted) prometheus.Labels {
+func (e *exporter) startedJobLabels(msg *scaleset.JobStarted) prometheus.Labels {
 	return e.jobLabels(&msg.JobMessageBase)
 }
 
 //go:generate mockery --name Publisher --output ./mocks --outpkg mocks --case underscore
 type Publisher interface {
 	PublishStatic(min, max int)
-	PublishStatistics(stats *actions.RunnerScaleSetStatistic)
-	PublishJobStarted(msg *actions.JobStarted)
-	PublishJobCompleted(msg *actions.JobCompleted)
+	PublishStatistics(stats *scaleset.RunnerScaleSetStatistic)
+	PublishJobStarted(msg *scaleset.JobStarted)
+	PublishJobCompleted(msg *scaleset.JobCompleted)
 	PublishDesiredRunners(count int)
 }
 
@@ -124,7 +123,7 @@ var (
 var Discard Publisher = &discard{}
 
 type exporter struct {
-	logger         logr.Logger
+	logger         *slog.Logger
 	scaleSetLabels prometheus.Labels
 	*metrics
 	srv *http.Server
@@ -310,7 +309,7 @@ func NewExporter(config ExporterConfig) ServerExporter {
 	)
 
 	return &exporter{
-		logger: config.Logger.WithName("metrics"),
+		logger: config.Logger.With("component", "metrics exporter"),
 		scaleSetLabels: prometheus.Labels{
 			labelKeyRunnerScaleSetName:      config.ScaleSetName,
 			labelKeyRunnerScaleSetNamespace: config.ScaleSetNamespace,
@@ -328,7 +327,7 @@ func NewExporter(config ExporterConfig) ServerExporter {
 
 var errUnknownMetricName = errors.New("unknown metric name")
 
-func installMetrics(config v1alpha1.MetricsConfig, reg *prometheus.Registry, logger logr.Logger) *metrics {
+func installMetrics(config v1alpha1.MetricsConfig, reg *prometheus.Registry, logger *slog.Logger) *metrics {
 	logger.Info(
 		"Registering metrics",
 		"gauges",
@@ -346,7 +345,7 @@ func installMetrics(config v1alpha1.MetricsConfig, reg *prometheus.Registry, log
 	for name, cfg := range config.Gauges {
 		help, ok := metricsHelp.gauges[name]
 		if !ok {
-			logger.Error(errUnknownMetricName, "name", name, "kind", "gauge")
+			logger.Error(errUnknownMetricName.Error(), "name", name, "kind", "gauge")
 			continue
 		}
 
@@ -368,7 +367,7 @@ func installMetrics(config v1alpha1.MetricsConfig, reg *prometheus.Registry, log
 	for name, cfg := range config.Counters {
 		help, ok := metricsHelp.counters[name]
 		if !ok {
-			logger.Error(errUnknownMetricName, "name", name, "kind", "counter")
+			logger.Error(errUnknownMetricName.Error(), "name", name, "kind", "counter")
 			continue
 		}
 		c := prometheus.V2.NewCounterVec(prometheus.CounterVecOpts{
@@ -389,7 +388,7 @@ func installMetrics(config v1alpha1.MetricsConfig, reg *prometheus.Registry, log
 	for name, cfg := range config.Histograms {
 		help, ok := metricsHelp.histograms[name]
 		if !ok {
-			logger.Error(errUnknownMetricName, "name", name, "kind", "histogram")
+			logger.Error(errUnknownMetricName.Error(), "name", name, "kind", "histogram")
 			continue
 		}
 
@@ -470,7 +469,7 @@ func (e *exporter) PublishStatic(min, max int) {
 	e.setGauge(MetricMinRunners, e.scaleSetLabels, float64(min))
 }
 
-func (e *exporter) PublishStatistics(stats *actions.RunnerScaleSetStatistic) {
+func (e *exporter) PublishStatistics(stats *scaleset.RunnerScaleSetStatistic) {
 	e.setGauge(MetricAssignedJobs, e.scaleSetLabels, float64(stats.TotalAssignedJobs))
 	e.setGauge(MetricRunningJobs, e.scaleSetLabels, float64(stats.TotalRunningJobs))
 	e.setGauge(MetricRegisteredRunners, e.scaleSetLabels, float64(stats.TotalRegisteredRunners))
@@ -478,7 +477,7 @@ func (e *exporter) PublishStatistics(stats *scaleset.RunnerScaleSetStatistic) {
 	e.setGauge(MetricIdleRunners, e.scaleSetLabels, float64(stats.TotalIdleRunners))
 }
 
-func (e *exporter) PublishJobStarted(msg *actions.JobStarted) {
+func (e *exporter) PublishJobStarted(msg *scaleset.JobStarted) {
 	l := e.startedJobLabels(msg)
 	e.incCounter(MetricStartedJobsTotal, l)
 
@@ -486,7 +485,7 @@ func (e *exporter) PublishJobStarted(msg *actions.JobStarted) {
 	e.observeHistogram(MetricJobStartupDurationSeconds, l, float64(startupDuration))
 }
 
-func (e *exporter) PublishJobCompleted(msg *actions.JobCompleted) {
+func (e *exporter) PublishJobCompleted(msg *scaleset.JobCompleted) {
 	l := e.completedJobLabels(msg)
 	e.incCounter(MetricCompletedJobsTotal, l)
 
@@ -500,11 +499,11 @@ func (e *exporter) PublishDesiredRunners(count int) {
 
 type discard struct{}
 
-func (*discard) PublishStatic(int, int)                             {}
-func (*discard) PublishStatistics(*actions.RunnerScaleSetStatistic) {}
-func (*discard) PublishJobStarted(*actions.JobStarted)              {}
-func (*discard) PublishJobCompleted(*actions.JobCompleted)          {}
-func (*discard) PublishDesiredRunners(int)                          {}
+func (*discard) PublishStatic(int, int)                              {}
+func (*discard) PublishStatistics(*scaleset.RunnerScaleSetStatistic) {}
+func (*discard) PublishJobStarted(*scaleset.JobStarted)              {}
+func (*discard) PublishJobCompleted(*scaleset.JobCompleted)          {}
+func (*discard) PublishDesiredRunners(int)                           {}
 
 var defaultRuntimeBuckets []float64 = []float64{
 	0.01,
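A note on the logging calls above, assuming log/slog is already imported in the part of the import block not shown in the hunk (the exporter struct now holds a *slog.Logger): with logr, the error value was the first argument to Error, while with slog the first argument is the message, followed by alternating key/value pairs. A minimal, self-contained illustration of the convention the corrected calls follow, with an illustrative metric name:

package main

import (
	"errors"
	"log/slog"
	"os"
)

func main() {
	errUnknownMetricName := errors.New("unknown metric name")
	logger := slog.New(slog.NewJSONHandler(os.Stderr, nil))
	// slog convention: message first, then key/value pairs.
	logger.Error(errUnknownMetricName.Error(), "name", "gha_assigned_jobs", "kind", "gauge")
}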
"actions", RepositoryName: "runner", JobDisplayName: "Build and Test", @@ -40,7 +40,7 @@ func TestMetricsWithWorkflowRefParsing(t *testing.T) { }, { name: "feature branch workflow", - jobBase: actions.JobMessageBase{ + jobBase: scaleset.JobMessageBase{ OwnerName: "myorg", RepositoryName: "myrepo", JobDisplayName: "CI/CD Pipeline", @@ -52,7 +52,7 @@ func TestMetricsWithWorkflowRefParsing(t *testing.T) { }, { name: "pull request workflow", - jobBase: actions.JobMessageBase{ + jobBase: scaleset.JobMessageBase{ OwnerName: "actions", RepositoryName: "runner", JobDisplayName: "PR Checks", @@ -64,7 +64,7 @@ func TestMetricsWithWorkflowRefParsing(t *testing.T) { }, { name: "tag workflow", - jobBase: actions.JobMessageBase{ + jobBase: scaleset.JobMessageBase{ OwnerName: "actions", RepositoryName: "runner", JobDisplayName: "Release", diff --git a/cmd/ghalistener/metrics/metrics_test.go b/cmd/ghalistener/metrics/metrics_test.go index 850560fb..e62d77e7 100644 --- a/cmd/ghalistener/metrics/metrics_test.go +++ b/cmd/ghalistener/metrics/metrics_test.go @@ -1,15 +1,17 @@ package metrics import ( + "log/slog" "testing" "github.com/actions/actions-runner-controller/apis/actions.github.com/v1alpha1" - "github.com/go-logr/logr" "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +var discardLogger = slog.New(slog.DiscardHandler) + func TestInstallMetrics(t *testing.T) { metricsConfig := v1alpha1.MetricsConfig{ Counters: map[string]*v1alpha1.CounterMetric{ @@ -74,7 +76,7 @@ func TestInstallMetrics(t *testing.T) { } reg := prometheus.NewRegistry() - got := installMetrics(metricsConfig, reg, logr.Discard()) + got := installMetrics(metricsConfig, reg, discardLogger) assert.Len(t, got.counters, 1) assert.Len(t, got.gauges, 1) assert.Len(t, got.histograms, 2) @@ -98,7 +100,7 @@ func TestNewExporter(t *testing.T) { Repository: "repo", ServerAddr: ":6060", ServerEndpoint: "/metrics", - Logger: logr.Discard(), + Logger: discardLogger, Metrics: nil, // when metrics is nil, all default metrics should be registered } @@ -140,7 +142,7 @@ func TestNewExporter(t *testing.T) { Repository: "repo", ServerAddr: "", // empty ServerAddr should default to ":8080" ServerEndpoint: "", - Logger: logr.Discard(), + Logger: discardLogger, Metrics: nil, // when metrics is nil, all default metrics should be registered } @@ -201,7 +203,7 @@ func TestNewExporter(t *testing.T) { Repository: "repo", ServerAddr: ":6060", ServerEndpoint: "/metrics", - Logger: logr.Discard(), + Logger: discardLogger, Metrics: &metricsConfig, } @@ -244,7 +246,7 @@ func TestExporterConfigDefaults(t *testing.T) { Repository: "repo", ServerAddr: "", ServerEndpoint: "", - Logger: logr.Discard(), + Logger: discardLogger, Metrics: nil, // when metrics is nil, all default metrics should be registered } @@ -257,7 +259,7 @@ func TestExporterConfigDefaults(t *testing.T) { Repository: "repo", ServerAddr: ":8080", // default server address ServerEndpoint: "/metrics", // default server endpoint - Logger: logr.Discard(), + Logger: discardLogger, Metrics: &defaultMetrics, // when metrics is nil, all default metrics should be registered } diff --git a/cmd/ghalistener/worker/worker.go b/cmd/ghalistener/worker/worker.go index bd039fb9..7e16a3eb 100644 --- a/cmd/ghalistener/worker/worker.go +++ b/cmd/ghalistener/worker/worker.go @@ -5,11 +5,11 @@ import ( "encoding/json" "fmt" "log/slog" + "math" "github.com/actions/actions-runner-controller/apis/actions.github.com/v1alpha1" - 
"github.com/actions/actions-runner-controller/cmd/ghalistener/listener" - "github.com/actions/actions-runner-controller/github/actions" - "github.com/actions/actions-runner-controller/logging" + "github.com/actions/scaleset" + "github.com/actions/scaleset/listener" jsonpatch "github.com/evanphx/json-patch" kerrors "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/types" @@ -35,20 +35,21 @@ type Config struct { // The Worker's role is to process the messages it receives from the listener. // It then initiates Kubernetes API requests to carry out the necessary actions. type Worker struct { - clientset *kubernetes.Clientset - config Config - lastPatch int - patchSeq int - logger *slog.Logger + clientset *kubernetes.Clientset + config Config + targetRunners int + patchSeq int + dirty bool + logger *slog.Logger } -var _ listener.Handler = (*Worker)(nil) +var _ listener.Scaler = (*Worker)(nil) func New(config Config, options ...Option) (*Worker, error) { w := &Worker{ - config: config, - lastPatch: -1, - patchSeq: -1, + config: config, + targetRunners: -1, + patchSeq: -1, } conf, err := rest.InClusterConfig() @@ -76,12 +77,7 @@ func New(config Config, options ...Option) (*Worker, error) { func (w *Worker) applyDefaults() error { if w.logger == nil { - logger, err := logging.NewLogger(logging.LogLevelDebug, logging.LogFormatJSON) - if err != nil { - return fmt.Errorf("NewLogger failed: %w", err) - } - logger = logger.WithName(workerName) - w.logger = &logger + w.logger = slog.New(slog.DiscardHandler) } return nil @@ -92,7 +88,7 @@ func (w *Worker) applyDefaults() error { // This update marks the ephemeral runner so that the controller would have more context // about the ephemeral runner that should not be deleted when scaling down. // It returns an error if there is any issue with updating the job information. -func (w *Worker) HandleJobStarted(ctx context.Context, jobInfo *actions.JobStarted) error { +func (w *Worker) HandleJobStarted(ctx context.Context, jobInfo *scaleset.JobStarted) error { w.logger.Info("Updating job info for the runner", "runnerName", jobInfo.RunnerName, "ownerName", jobInfo.OwnerName, @@ -103,6 +99,8 @@ func (w *Worker) HandleJobStarted(ctx context.Context, jobInfo *actions.JobStart "jobDisplayName", jobInfo.JobDisplayName, "requestId", jobInfo.RunnerRequestID) + w.dirty = true + original, err := json.Marshal(&v1alpha1.EphemeralRunner{}) if err != nil { return fmt.Errorf("failed to marshal empty ephemeral runner: %w", err) @@ -155,6 +153,11 @@ func (w *Worker) HandleJobStarted(ctx context.Context, jobInfo *actions.JobStart return nil } +func (w *Worker) HandleJobCompleted(ctx context.Context, msg *scaleset.JobCompleted) error { + w.dirty = true + return nil +} + // HandleDesiredRunnerCount handles the desired runner count by scaling the ephemeral runner set. // The function calculates the target runner count based on the minimum and maximum runner count configuration. // If the target runner count is the same as the last patched count, it skips patching and returns nil. @@ -162,8 +165,8 @@ func (w *Worker) HandleJobStarted(ctx context.Context, jobInfo *actions.JobStart // The function then scales the ephemeral runner set by applying the merge patch. // Finally, it logs the scaled ephemeral runner set details and returns nil if successful. // If any error occurs during the process, it returns an error with a descriptive message. 
 // HandleDesiredRunnerCount handles the desired runner count by scaling the ephemeral runner set.
 // The function calculates the target runner count based on the minimum and maximum runner count configuration.
 // If the target runner count is the same as the last patched count, it skips patching and returns nil.
 // Otherwise, it creates a merge patch for the ephemeral runner set with the new replica count and patch ID.
 // The function then scales the ephemeral runner set by applying the merge patch.
 // Finally, it logs the scaled ephemeral runner set details and returns nil if successful.
 // If any error occurs during the process, it returns an error with a descriptive message.
-func (w *Worker) HandleDesiredRunnerCount(ctx context.Context, count, jobsCompleted int) (int, error) {
-	patchID := w.setDesiredWorkerState(count, jobsCompleted)
+func (w *Worker) HandleDesiredRunnerCount(ctx context.Context, count int) (int, error) {
+	patchID := w.setDesiredWorkerState(count)
 
 	original, err := json.Marshal(
 		&v1alpha1.EphemeralRunnerSet{
@@ -180,13 +183,13 @@ func (w *Worker) HandleDesiredRunnerCount(ctx context.Context, count, jobsComple
 	patch, err := json.Marshal(
 		&v1alpha1.EphemeralRunnerSet{
 			Spec: v1alpha1.EphemeralRunnerSetSpec{
-				Replicas: w.lastPatch,
+				Replicas: w.targetRunners,
 				PatchID:  patchID,
 			},
 		},
 	)
 	if err != nil {
-		w.logger.Error(err, "could not marshal patch ephemeral runner set")
+		w.logger.Error("could not marshal patch ephemeral runner set", "error", err.Error())
 		return 0, err
 	}
 
@@ -217,30 +220,28 @@ func (w *Worker) HandleDesiredRunnerCount(ctx context.Context, count, jobsComple
 		"name", w.config.EphemeralRunnerSetName,
 		"replicas", patchedEphemeralRunnerSet.Spec.Replicas,
 	)
-	return w.lastPatch, nil
+	return w.targetRunners, nil
 }
 
-// calculateDesiredState calculates the desired state of the worker based on the desired count and the the number of jobs completed.
-func (w *Worker) setDesiredWorkerState(count, jobsCompleted int) int {
-	// Max runners should always be set by the resource builder either to the configured value,
-	// or the maximum int32 (resourcebuilder.newAutoScalingListener()).
-	targetRunnerCount := min(w.config.MinRunners+count, w.config.MaxRunners)
-	w.patchSeq++
-	desiredPatchID := w.patchSeq
-
-	if count == 0 && jobsCompleted == 0 { // empty batch
-		targetRunnerCount = max(w.lastPatch, targetRunnerCount)
-		if targetRunnerCount == w.config.MinRunners {
-			// We have an empty batch, and the last patch was the min runners.
-			// Since this is an empty batch, and we are at the min runners, they should all be idle.
-			// If controller created few more pods on accident (during scale down events),
-			// this situation allows the controller to scale down to the min runners.
-			// However, it is important to keep the patch sequence increasing so we don't ignore one batch.
-			desiredPatchID = 0
-		}
-	}
-	w.lastPatch = targetRunnerCount
+// setDesiredWorkerState calculates the desired state of the worker based on the desired count
+// and whether any job started or completed events arrived since the last patch.
+func (w *Worker) setDesiredWorkerState(count int) int {
+	dirty := w.dirty
+	w.dirty = false
+
+	if w.patchSeq == math.MaxInt32 {
+		w.patchSeq = 0
+	}
+	w.patchSeq++
+
+	desiredPatchID := w.patchSeq
+
+	targetRunnerCount := min(w.config.MinRunners+count, w.config.MaxRunners)
+	if !dirty {
+		// A quiet batch: no jobs started or completed since the last patch.
+		// Re-use the previous target so an empty batch cannot scale the set down,
+		// and if we are already at the min runners, force the state by setting the
+		// patch ID to 0 so the controller can reconcile any surplus pods. The patch
+		// sequence keeps increasing either way so no batch is ignored.
+		targetRunnerCount = max(w.targetRunners, targetRunnerCount)
+		if targetRunnerCount == w.config.MinRunners {
+			desiredPatchID = 0
+		}
+	}
+	w.targetRunners = targetRunnerCount
 
 	w.logger.Info(
 		"Calculated target runner count",
@@ -248,8 +249,7 @@ func (w *Worker) setDesiredWorkerState(count, jobsCompleted int) int {
 		"decision", targetRunnerCount,
 		"min", w.config.MinRunners,
 		"max", w.config.MaxRunners,
-		"currentRunnerCount", w.lastPatch,
-		"jobsCompleted", jobsCompleted,
+		"currentRunnerCount", w.targetRunners,
 	)
 
 	return desiredPatchID
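To make the patch-ID sequencing concrete, here is the MinSet scenario the tests below encode, assuming the setDesiredWorkerState shown above (the function is unexported, so this sketch lives inside package worker):

package worker

import (
	"log/slog"
	"math"
)

// patchSequenceWalkthrough mirrors the "desired patch is 0 but sequence
// continues" subtest; the results are annotated inline.
func patchSequenceWalkthrough() {
	w := &Worker{
		config:        Config{MinRunners: 1, MaxRunners: math.MaxInt32},
		targetRunners: -1,
		patchSeq:      -1,
		logger:        slog.New(slog.DiscardHandler),
	}

	_ = w.setDesiredWorkerState(3) // patchSeq 0, target 4, returns patch ID 0
	w.dirty = true                 // a job event arrived in this batch
	_ = w.setDesiredWorkerState(0) // patchSeq 1, target 1 (the min), returns 1
	_ = w.setDesiredWorkerState(0) // quiet batch at min: patchSeq 2, returns 0 (forced)
}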
diff --git a/cmd/ghalistener/worker/worker_test.go b/cmd/ghalistener/worker/worker_test.go
index d009bccf..b10648c9 100644
--- a/cmd/ghalistener/worker/worker_test.go
+++ b/cmd/ghalistener/worker/worker_test.go
@@ -1,326 +1,377 @@
 package worker
 
 import (
+	"log/slog"
 	"math"
 	"testing"
 
-	"github.com/go-logr/logr"
 	"github.com/stretchr/testify/assert"
 )
 
 func TestSetDesiredWorkerState_MinMaxDefaults(t *testing.T) {
-	logger := logr.Discard()
 	newEmptyWorker := func() *Worker {
 		return &Worker{
 			config: Config{
 				MinRunners: 0,
 				MaxRunners: math.MaxInt32,
 			},
-			lastPatch: -1,
-			patchSeq:  -1,
-			logger:    &logger,
+			targetRunners: -1,
+			patchSeq:      -1,
+			logger:        slog.New(slog.DiscardHandler),
 		}
 	}
 
 	t.Run("init calculate with acquired 0", func(t *testing.T) {
 		w := newEmptyWorker()
-		patchID := w.setDesiredWorkerState(0, 0)
-		assert.Equal(t, 0, w.lastPatch)
+		patchID := w.setDesiredWorkerState(0)
+		assert.False(t, w.dirty)
+		assert.Equal(t, 0, w.targetRunners)
 		assert.Equal(t, 0, w.patchSeq)
 		assert.Equal(t, 0, patchID)
 	})
 
 	t.Run("init calculate with acquired 1", func(t *testing.T) {
 		w := newEmptyWorker()
-		patchID := w.setDesiredWorkerState(1, 0)
-		assert.Equal(t, 1, w.lastPatch)
+		patchID := w.setDesiredWorkerState(1)
+		assert.False(t, w.dirty)
+		assert.Equal(t, 1, w.targetRunners)
 		assert.Equal(t, 0, w.patchSeq)
 		assert.Equal(t, 0, patchID)
 	})
 
 	t.Run("increment patch when job done", func(t *testing.T) {
 		w := newEmptyWorker()
-		patchID := w.setDesiredWorkerState(1, 0)
+		patchID := w.setDesiredWorkerState(1)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 0, patchID)
 
-		patchID = w.setDesiredWorkerState(0, 1)
+		w.dirty = true
+		patchID = w.setDesiredWorkerState(0)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 1, patchID)
-		assert.Equal(t, 0, w.lastPatch)
+		assert.Equal(t, 0, w.targetRunners)
 		assert.Equal(t, 1, w.patchSeq)
 	})
 
 	t.Run("increment patch when called with same parameters", func(t *testing.T) {
 		w := newEmptyWorker()
-		patchID := w.setDesiredWorkerState(1, 0)
+		patchID := w.setDesiredWorkerState(1)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 0, patchID)
 
-		patchID = w.setDesiredWorkerState(1, 0)
+		patchID = w.setDesiredWorkerState(1)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 1, patchID)
-		assert.Equal(t, 1, w.lastPatch)
+		assert.Equal(t, 1, w.targetRunners)
 		assert.Equal(t, 1, w.patchSeq)
 	})
 
 	t.Run("calculate desired scale when acquired > 0 and completed > 0", func(t *testing.T) {
 		w := newEmptyWorker()
-		patchID := w.setDesiredWorkerState(1, 1)
+		w.dirty = true
+		patchID := w.setDesiredWorkerState(1)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 0, patchID)
-		assert.Equal(t, 1, w.lastPatch)
+		assert.Equal(t, 1, w.targetRunners)
 		assert.Equal(t, 0, w.patchSeq)
 	})
 
 	t.Run("re-use the last state when acquired == 0 and completed == 0", func(t *testing.T) {
 		w := newEmptyWorker()
-		patchID := w.setDesiredWorkerState(1, 0)
+		patchID := w.setDesiredWorkerState(1)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 0, patchID)
 
-		patchID = w.setDesiredWorkerState(0, 0)
+		patchID = w.setDesiredWorkerState(0)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 1, patchID)
-		assert.Equal(t, 1, w.lastPatch)
+		assert.Equal(t, 1, w.targetRunners)
 		assert.Equal(t, 1, w.patchSeq)
 	})
 
 	t.Run("adjust when acquired == 0 and completed == 1", func(t *testing.T) {
 		w := newEmptyWorker()
-		patchID := w.setDesiredWorkerState(1, 1)
+		w.dirty = true
+		patchID := w.setDesiredWorkerState(1)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 0, patchID)
 
-		patchID = w.setDesiredWorkerState(0, 1)
+		assert.False(t, w.dirty)
+		w.dirty = true
+		patchID = w.setDesiredWorkerState(0)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 1, patchID)
-		assert.Equal(t, 0, w.lastPatch)
+		assert.Equal(t, 0, w.targetRunners)
 		assert.Equal(t, 1, w.patchSeq)
 	})
 }
 
 func TestSetDesiredWorkerState_MinSet(t *testing.T) {
-	logger := logr.Discard()
 	newEmptyWorker := func() *Worker {
 		return &Worker{
 			config: Config{
 				MinRunners: 1,
 				MaxRunners: math.MaxInt32,
 			},
-			lastPatch: -1,
-			patchSeq:  -1,
-			logger:    &logger,
+			targetRunners: -1,
+			patchSeq:      -1,
+			logger:        slog.New(slog.DiscardHandler),
 		}
 	}
 
 	t.Run("initial scale when acquired == 0 and completed == 0", func(t *testing.T) {
 		w := newEmptyWorker()
-		patchID := w.setDesiredWorkerState(0, 0)
+		patchID := w.setDesiredWorkerState(0)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 0, patchID)
-		assert.Equal(t, 1, w.lastPatch)
+		assert.Equal(t, 1, w.targetRunners)
 		assert.Equal(t, 0, w.patchSeq)
 	})
 
 	t.Run("re-use the old state on count == 0 and completed == 0", func(t *testing.T) {
 		w := newEmptyWorker()
-		patchID := w.setDesiredWorkerState(2, 0)
+		patchID := w.setDesiredWorkerState(2)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 0, patchID)
 
-		patchID = w.setDesiredWorkerState(0, 0)
+		patchID = w.setDesiredWorkerState(0)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 1, patchID)
-		assert.Equal(t, 3, w.lastPatch)
+		assert.Equal(t, 3, w.targetRunners)
 		assert.Equal(t, 1, w.patchSeq)
 	})
 
 	t.Run("request back to 0 on job done", func(t *testing.T) {
 		w := newEmptyWorker()
-		patchID := w.setDesiredWorkerState(2, 0)
+		patchID := w.setDesiredWorkerState(2)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 0, patchID)
 
-		patchID = w.setDesiredWorkerState(0, 1)
+		w.dirty = true
+		patchID = w.setDesiredWorkerState(0)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 1, patchID)
-		assert.Equal(t, 1, w.lastPatch)
+		assert.Equal(t, 1, w.targetRunners)
 		assert.Equal(t, 1, w.patchSeq)
 	})
 
 	t.Run("desired patch is 0 but sequence continues on empty batch and min runners", func(t *testing.T) {
 		w := newEmptyWorker()
-		patchID := w.setDesiredWorkerState(3, 0)
+		patchID := w.setDesiredWorkerState(3)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 0, patchID)
-		assert.Equal(t, 4, w.lastPatch)
+		assert.Equal(t, 4, w.targetRunners)
 		assert.Equal(t, 0, w.patchSeq)
 
-		patchID = w.setDesiredWorkerState(0, 3)
+		w.dirty = true
+		patchID = w.setDesiredWorkerState(0)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 1, patchID)
-		assert.Equal(t, 1, w.lastPatch)
+		assert.Equal(t, 1, w.targetRunners)
 		assert.Equal(t, 1, w.patchSeq)
 
 		// Empty batch on min runners
-		patchID = w.setDesiredWorkerState(0, 0)
+		patchID = w.setDesiredWorkerState(0)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 0, patchID) // forcing the state
-		assert.Equal(t, 1, w.lastPatch)
+		assert.Equal(t, 1, w.targetRunners)
 		assert.Equal(t, 2, w.patchSeq)
 	})
-
 }
 
 func TestSetDesiredWorkerState_MaxSet(t *testing.T) {
-	logger := logr.Discard()
 	newEmptyWorker := func() *Worker {
 		return &Worker{
 			config: Config{
 				MinRunners: 0,
 				MaxRunners: 5,
 			},
-			lastPatch: -1,
-			patchSeq:  -1,
-			logger:    &logger,
+			targetRunners: -1,
+			patchSeq:      -1,
+			logger:        slog.New(slog.DiscardHandler),
 		}
 	}
 
 	t.Run("initial scale when acquired == 0 and completed == 0", func(t *testing.T) {
 		w := newEmptyWorker()
-		patchID := w.setDesiredWorkerState(0, 0)
+		patchID := w.setDesiredWorkerState(0)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 0, patchID)
-		assert.Equal(t, 0, w.lastPatch)
+		assert.Equal(t, 0, w.targetRunners)
 		assert.Equal(t, 0, w.patchSeq)
 	})
 
 	t.Run("re-use the old state on count == 0 and completed == 0", func(t *testing.T) {
 		w := newEmptyWorker()
-		patchID := w.setDesiredWorkerState(2, 0)
+		patchID := w.setDesiredWorkerState(2)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 0, patchID)
 
-		patchID = w.setDesiredWorkerState(0, 0)
+		patchID = w.setDesiredWorkerState(0)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 1, patchID)
-		assert.Equal(t, 2, w.lastPatch)
+		assert.Equal(t, 2, w.targetRunners)
 		assert.Equal(t, 1, w.patchSeq)
 	})
 
 	t.Run("request back to 0 on job done", func(t *testing.T) {
 		w := newEmptyWorker()
-		patchID := w.setDesiredWorkerState(2, 0)
+		patchID := w.setDesiredWorkerState(2)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 0, patchID)
 
-		patchID = w.setDesiredWorkerState(0, 1)
+		w.dirty = true
+		patchID = w.setDesiredWorkerState(0)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 1, patchID)
-		assert.Equal(t, 0, w.lastPatch)
+		assert.Equal(t, 0, w.targetRunners)
 		assert.Equal(t, 1, w.patchSeq)
 	})
 
 	t.Run("scale up to max when count > max", func(t *testing.T) {
 		w := newEmptyWorker()
-		patchID := w.setDesiredWorkerState(6, 0)
+		patchID := w.setDesiredWorkerState(6)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 0, patchID)
-		assert.Equal(t, 5, w.lastPatch)
+		assert.Equal(t, 5, w.targetRunners)
 		assert.Equal(t, 0, w.patchSeq)
 	})
 
 	t.Run("scale to max when count == max", func(t *testing.T) {
 		w := newEmptyWorker()
-		w.setDesiredWorkerState(5, 0)
-		assert.Equal(t, 5, w.lastPatch)
+		w.setDesiredWorkerState(5)
+		assert.False(t, w.dirty)
+		assert.Equal(t, 5, w.targetRunners)
 		assert.Equal(t, 0, w.patchSeq)
 	})
 
 	t.Run("scale to max when count > max and completed > 0", func(t *testing.T) {
 		w := newEmptyWorker()
-		patchID := w.setDesiredWorkerState(1, 0)
+		patchID := w.setDesiredWorkerState(1)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 0, patchID)
 
-		patchID = w.setDesiredWorkerState(6, 1)
+		w.dirty = true
+		patchID = w.setDesiredWorkerState(6)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 1, patchID)
-		assert.Equal(t, 5, w.lastPatch)
+		assert.Equal(t, 5, w.targetRunners)
 		assert.Equal(t, 1, w.patchSeq)
 	})
 
 	t.Run("scale back to 0 when count was > max", func(t *testing.T) {
 		w := newEmptyWorker()
-		patchID := w.setDesiredWorkerState(6, 0)
+		patchID := w.setDesiredWorkerState(6)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 0, patchID)
 
-		patchID = w.setDesiredWorkerState(0, 1)
+		w.dirty = true
+		patchID = w.setDesiredWorkerState(0)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 1, patchID)
-		assert.Equal(t, 0, w.lastPatch)
+		assert.Equal(t, 0, w.targetRunners)
 		assert.Equal(t, 1, w.patchSeq)
 	})
 
 	t.Run("force 0 on empty batch and last patch == min runners", func(t *testing.T) {
 		w := newEmptyWorker()
-		patchID := w.setDesiredWorkerState(3, 0)
+		patchID := w.setDesiredWorkerState(3)
 		assert.Equal(t, 0, patchID)
-		assert.Equal(t, 3, w.lastPatch)
+		assert.Equal(t, 3, w.targetRunners)
 		assert.Equal(t, 0, w.patchSeq)
 
-		patchID = w.setDesiredWorkerState(0, 3)
+		w.dirty = true
+		patchID = w.setDesiredWorkerState(0)
 		assert.Equal(t, 1, patchID)
-		assert.Equal(t, 0, w.lastPatch)
+		assert.Equal(t, 0, w.targetRunners)
 		assert.Equal(t, 1, w.patchSeq)
 
 		// Empty batch on min runners
-		patchID = w.setDesiredWorkerState(0, 0)
+		patchID = w.setDesiredWorkerState(0)
 		assert.Equal(t, 0, patchID) // forcing the state
-		assert.Equal(t, 0, w.lastPatch)
+		assert.Equal(t, 0, w.targetRunners)
 		assert.Equal(t, 2, w.patchSeq)
 	})
 }
 
 func TestSetDesiredWorkerState_MinMaxSet(t *testing.T) {
-	logger := logr.Discard()
 	newEmptyWorker := func() *Worker {
 		return &Worker{
 			config: Config{
 				MinRunners: 1,
 				MaxRunners: 3,
 			},
-			lastPatch: -1,
-			patchSeq:  -1,
-			logger:    &logger,
+			targetRunners: -1,
+			patchSeq:      -1,
+			logger:        slog.New(slog.DiscardHandler),
 		}
 	}
 
 	t.Run("initial scale when acquired == 0 and completed == 0", func(t *testing.T) {
 		w := newEmptyWorker()
-		patchID := w.setDesiredWorkerState(0, 0)
+		patchID := w.setDesiredWorkerState(0)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 0, patchID)
-		assert.Equal(t, 1, w.lastPatch)
+		assert.Equal(t, 1, w.targetRunners)
 		assert.Equal(t, 0, w.patchSeq)
 	})
 
 	t.Run("re-use the old state on count == 0 and completed == 0", func(t *testing.T) {
 		w := newEmptyWorker()
-		patchID := w.setDesiredWorkerState(2, 0)
+		patchID := w.setDesiredWorkerState(2)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 0, patchID)
 
-		patchID = w.setDesiredWorkerState(0, 0)
+		patchID = w.setDesiredWorkerState(0)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 1, patchID)
-		assert.Equal(t, 3, w.lastPatch)
+		assert.Equal(t, 3, w.targetRunners)
 		assert.Equal(t, 1, w.patchSeq)
 	})
 
 	t.Run("scale to min when count == 0", func(t *testing.T) {
 		w := newEmptyWorker()
-		patchID := w.setDesiredWorkerState(2, 0)
+		patchID := w.setDesiredWorkerState(2)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 0, patchID)
 
-		patchID = w.setDesiredWorkerState(0, 1)
+		w.dirty = true
+		patchID = w.setDesiredWorkerState(0)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 1, patchID)
-		assert.Equal(t, 1, w.lastPatch)
+		assert.Equal(t, 1, w.targetRunners)
 		assert.Equal(t, 1, w.patchSeq)
 	})
 
 	t.Run("scale up to max when count > max", func(t *testing.T) {
 		w := newEmptyWorker()
-		patchID := w.setDesiredWorkerState(4, 0)
+		patchID := w.setDesiredWorkerState(4)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 0, patchID)
-		assert.Equal(t, 3, w.lastPatch)
+		assert.Equal(t, 3, w.targetRunners)
 		assert.Equal(t, 0, w.patchSeq)
 	})
 
 	t.Run("scale to max when count == max", func(t *testing.T) {
 		w := newEmptyWorker()
-		patchID := w.setDesiredWorkerState(3, 0)
+		patchID := w.setDesiredWorkerState(3)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 0, patchID)
-		assert.Equal(t, 3, w.lastPatch)
+		assert.Equal(t, 3, w.targetRunners)
 		assert.Equal(t, 0, w.patchSeq)
 	})
 
 	t.Run("force 0 on empty batch and last patch == min runners", func(t *testing.T) {
 		w := newEmptyWorker()
-		patchID := w.setDesiredWorkerState(3, 0)
+		patchID := w.setDesiredWorkerState(3)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 0, patchID)
-		assert.Equal(t, 3, w.lastPatch)
+		assert.Equal(t, 3, w.targetRunners)
 		assert.Equal(t, 0, w.patchSeq)
 
-		patchID = w.setDesiredWorkerState(0, 3)
+		w.dirty = true
+		patchID = w.setDesiredWorkerState(0)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 1, patchID)
-		assert.Equal(t, 1, w.lastPatch)
+		assert.Equal(t, 1, w.targetRunners)
 		assert.Equal(t, 1, w.patchSeq)
 
 		// Empty batch on min runners
-		patchID = w.setDesiredWorkerState(0, 0)
+		patchID = w.setDesiredWorkerState(0)
+		assert.False(t, w.dirty)
 		assert.Equal(t, 0, patchID) // forcing the state
-		assert.Equal(t, 1, w.lastPatch)
+		assert.Equal(t, 1, w.targetRunners)
 		assert.Equal(t, 2, w.patchSeq)
 	})
 }
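The four nearly identical test functions above could be collapsed into a table once the Scaler API settles. A sketch, with hypothetical case names, that also lives inside package worker so it can reach the unexported fields:

package worker

import (
	"log/slog"
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestSetDesiredWorkerState_Table(t *testing.T) {
	cases := []struct {
		name        string
		min, max    int
		dirty       bool
		count       int
		wantPatchID int
		wantTarget  int
	}{
		{name: "quiet batch at min forces patch 0", min: 1, max: 3, dirty: false, count: 0, wantPatchID: 0, wantTarget: 1},
		{name: "dirty batch returns current sequence", min: 1, max: 3, dirty: true, count: 2, wantPatchID: 0, wantTarget: 3},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			w := &Worker{
				config:        Config{MinRunners: tc.min, MaxRunners: tc.max},
				targetRunners: -1,
				patchSeq:      -1,
				dirty:         tc.dirty,
				logger:        slog.New(slog.DiscardHandler),
			}
			patchID := w.setDesiredWorkerState(tc.count)
			assert.Equal(t, tc.wantPatchID, patchID)
			assert.Equal(t, tc.wantTarget, w.targetRunners)
		})
	}
}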
diff --git a/go.mod b/go.mod
index 96908163..fb822ea6 100644
--- a/go.mod
+++ b/go.mod
@@ -48,7 +48,7 @@ require (
 	github.com/BurntSushi/toml v1.5.0 // indirect
 	github.com/Masterminds/semver/v3 v3.4.0 // indirect
 	github.com/ProtonMail/go-crypto v1.3.0 // indirect
-	github.com/actions/scaleset v0.1.0 // indirect
+	github.com/actions/scaleset v0.1.1-0.20260217091257-f9f801fb3898 // indirect
 	github.com/aws/aws-sdk-go-v2 v1.39.2 // indirect
 	github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.1 // indirect
 	github.com/aws/aws-sdk-go-v2/config v1.31.12 // indirect
diff --git a/go.sum b/go.sum
index a9344cb2..916b1908 100644
--- a/go.sum
+++ b/go.sum
@@ -27,6 +27,8 @@ github.com/actions-runner-controller/httpcache v0.2.0 h1:hCNvYuVPJ2xxYBymqBvH0hS
 github.com/actions-runner-controller/httpcache v0.2.0/go.mod h1:JLu9/2M/btPz1Zu/vTZ71XzukQHn2YeISPmJoM5exBI=
 github.com/actions/scaleset v0.1.0 h1:Rzov5AqcphrQV+VfcPWUAK+hdVJzzJihr/qof1YjZx8=
 github.com/actions/scaleset v0.1.0/go.mod h1:ncR5vzCCTUSyLgvclAtZ5dRBgF6qwA2nbTfTXmOJp84=
+github.com/actions/scaleset v0.1.1-0.20260217091257-f9f801fb3898 h1:IoEq9noTGLFlugkXM5Sr8P3dabGKcbAdMm8f7yx3Dw0=
+github.com/actions/scaleset v0.1.1-0.20260217091257-f9f801fb3898/go.mod h1:ncR5vzCCTUSyLgvclAtZ5dRBgF6qwA2nbTfTXmOJp84=
 github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
 github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
 github.com/aws/aws-sdk-go-v2 v1.39.2 h1:EJLg8IdbzgeD7xgvZ+I8M1e0fL0ptn/M47lianzth0I=
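For reference, the scaleset bump is a Go pseudo-version for an untagged commit: v0.1.1-0.20260217091257-f9f801fb3898 decodes as "a commit after tag v0.1.0", with the UTC commit timestamp (20260217091257) and the 12-character commit hash (f9f801fb3898) embedded in the version string. Since app.go and worker.go now import the module directly, a later go mod tidy would be expected to promote the requirement out of the indirect block, i.e. to something like:

require github.com/actions/scaleset v0.1.1-0.20260217091257-f9f801fb3898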