diff --git a/cmd/ghalistener/app/app.go b/cmd/ghalistener/app/app.go
index e8f64f21..2d903fa9 100644
--- a/cmd/ghalistener/app/app.go
+++ b/cmd/ghalistener/app/app.go
@@ -34,7 +34,7 @@ type Listener interface {
 //go:generate mockery --name Worker --output ./mocks --outpkg mocks --case underscore
 type Worker interface {
 	HandleJobStarted(ctx context.Context, jobInfo *actions.JobStarted) error
-	HandleDesiredRunnerCount(ctx context.Context, desiredRunnerCount int) error
+	HandleDesiredRunnerCount(ctx context.Context, count int) (int, error)
 }
 
 func New(config config.Config) (*App, error) {
diff --git a/cmd/ghalistener/app/mocks/worker.go b/cmd/ghalistener/app/mocks/worker.go
index a2561adb..69828c38 100644
--- a/cmd/ghalistener/app/mocks/worker.go
+++ b/cmd/ghalistener/app/mocks/worker.go
@@ -15,18 +15,28 @@ type Worker struct {
 	mock.Mock
 }
 
-// HandleDesiredRunnerCount provides a mock function with given fields: ctx, desiredRunnerCount
-func (_m *Worker) HandleDesiredRunnerCount(ctx context.Context, desiredRunnerCount int) error {
-	ret := _m.Called(ctx, desiredRunnerCount)
+// HandleDesiredRunnerCount provides a mock function with given fields: ctx, count
+func (_m *Worker) HandleDesiredRunnerCount(ctx context.Context, count int) (int, error) {
+	ret := _m.Called(ctx, count)
 
-	var r0 error
-	if rf, ok := ret.Get(0).(func(context.Context, int) error); ok {
-		r0 = rf(ctx, desiredRunnerCount)
+	var r0 int
+	var r1 error
+	if rf, ok := ret.Get(0).(func(context.Context, int) (int, error)); ok {
+		return rf(ctx, count)
+	}
+	if rf, ok := ret.Get(0).(func(context.Context, int) int); ok {
+		r0 = rf(ctx, count)
 	} else {
-		r0 = ret.Error(0)
+		r0 = ret.Get(0).(int)
 	}
 
-	return r0
+	if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+		r1 = rf(ctx, count)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
 }
 
 // HandleJobStarted provides a mock function with given fields: ctx, jobInfo
diff --git a/cmd/ghalistener/listener/listener.go b/cmd/ghalistener/listener/listener.go
index e90622a0..cb611239 100644
--- a/cmd/ghalistener/listener/listener.go
+++ b/cmd/ghalistener/listener/listener.go
@@ -113,7 +113,7 @@ func New(config Config) (*Listener, error) {
 //go:generate mockery --name Handler --output ./mocks --outpkg mocks --case underscore
 type Handler interface {
 	HandleJobStarted(ctx context.Context, jobInfo *actions.JobStarted) error
-	HandleDesiredRunnerCount(ctx context.Context, desiredRunnerCount int) error
+	HandleDesiredRunnerCount(ctx context.Context, count int) (int, error)
 }
 
 // Listen listens for incoming messages and handles them using the provided handler.
@@ -133,28 +133,21 @@ func (l *Listener) Listen(ctx context.Context, handler Handler) error {
 		Body:        "",
 	}
 
-	if l.session.Statistics.TotalAvailableJobs > 0 || l.session.Statistics.TotalAssignedJobs > 0 {
-		acquirableJobs, err := l.client.GetAcquirableJobs(ctx, l.scaleSetID)
-		if err != nil {
-			return fmt.Errorf("failed to call GetAcquirableJobs: %w", err)
-		}
-
-		acquirableJobsJson, err := json.Marshal(acquirableJobs)
-		if err != nil {
-			return fmt.Errorf("failed to marshal acquirable jobs: %w", err)
-		}
-
-		initialMessage.Body = string(acquirableJobsJson)
+	if l.session.Statistics == nil {
+		return fmt.Errorf("session statistics is nil")
 	}
+	l.metrics.PublishStatistics(initialMessage.Statistics)
 
-	if err := handler.HandleDesiredRunnerCount(ctx, initialMessage.Statistics.TotalAssignedJobs); err != nil {
+	desiredRunners, err := handler.HandleDesiredRunnerCount(ctx, initialMessage.Statistics.TotalAssignedJobs)
+	if err != nil {
 		return fmt.Errorf("handling initial message failed: %w", err)
 	}
+	l.metrics.PublishDesiredRunners(desiredRunners)
 
 	for {
 		select {
 		case <-ctx.Done():
-			return fmt.Errorf("context cancelled: %w", ctx.Err())
+			return ctx.Err()
 		default:
 		}
 
@@ -167,29 +160,54 @@ func (l *Listener) Listen(ctx context.Context, handler Handler) error {
 			continue
 		}
 
-		statistics, jobsStarted, err := l.parseMessage(ctx, msg)
-		if err != nil {
-			return fmt.Errorf("failed to parse message: %w", err)
-		}
-
-		l.lastMessageID = msg.MessageId
-
-		if err := l.deleteLastMessage(ctx); err != nil {
-			return fmt.Errorf("failed to delete message: %w", err)
-		}
-
-		for _, jobStarted := range jobsStarted {
-			if err := handler.HandleJobStarted(ctx, jobStarted); err != nil {
-				return fmt.Errorf("failed to handle job started: %w", err)
-			}
-		}
-
-		if err := handler.HandleDesiredRunnerCount(ctx, statistics.TotalAssignedJobs); err != nil {
-			return fmt.Errorf("failed to handle desired runner count: %w", err)
+		// New context is created to avoid cancellation during message handling.
+		if err := l.handleMessage(context.Background(), handler, msg); err != nil {
+			return fmt.Errorf("failed to handle message: %w", err)
 		}
 	}
 }
 
+func (l *Listener) handleMessage(ctx context.Context, handler Handler, msg *actions.RunnerScaleSetMessage) error {
+	parsedMsg, err := l.parseMessage(ctx, msg)
+	if err != nil {
+		return fmt.Errorf("failed to parse message: %w", err)
+	}
+	l.metrics.PublishStatistics(parsedMsg.statistics)
+
+	if len(parsedMsg.jobsAvailable) > 0 {
+		acquiredJobIDs, err := l.acquireAvailableJobs(ctx, parsedMsg.jobsAvailable)
+		if err != nil {
+			return fmt.Errorf("failed to acquire jobs: %w", err)
+		}
+
+		l.logger.Info("Jobs are acquired", "count", len(acquiredJobIDs), "requestIds", fmt.Sprint(acquiredJobIDs))
+	}
+
+	for _, jobCompleted := range parsedMsg.jobsCompleted {
+		l.metrics.PublishJobCompleted(jobCompleted)
+	}
+
+	l.lastMessageID = msg.MessageId
+
+	if err := l.deleteLastMessage(ctx); err != nil {
+		return fmt.Errorf("failed to delete message: %w", err)
+	}
+
+	for _, jobStarted := range parsedMsg.jobsStarted {
+		if err := handler.HandleJobStarted(ctx, jobStarted); err != nil {
+			return fmt.Errorf("failed to handle job started: %w", err)
+		}
+		l.metrics.PublishJobStarted(jobStarted)
+	}
+
+	desiredRunners, err := handler.HandleDesiredRunnerCount(ctx, parsedMsg.statistics.TotalAssignedJobs)
+	if err != nil {
+		return fmt.Errorf("failed to handle desired runner count: %w", err)
+	}
+	l.metrics.PublishDesiredRunners(desiredRunners)
+	return nil
+}
+
 func (l *Listener) createSession(ctx context.Context) error {
 	var session *actions.RunnerScaleSetSession
 	var retries int
@@ -271,48 +289,57 @@ func (l *Listener) deleteLastMessage(ctx context.Context) error {
 	return nil
 }
 
-func (l *Listener) parseMessage(ctx context.Context, msg *actions.RunnerScaleSetMessage) (*actions.RunnerScaleSetStatistic, []*actions.JobStarted, error) {
+type parsedMessage struct {
+	statistics    *actions.RunnerScaleSetStatistic
+	jobsStarted   []*actions.JobStarted
+	jobsAvailable []*actions.JobAvailable
+	jobsCompleted []*actions.JobCompleted
+}
+
+func (l *Listener) parseMessage(ctx context.Context, msg *actions.RunnerScaleSetMessage) (*parsedMessage, error) {
+	if msg.MessageType != "RunnerScaleSetJobMessages" {
+		l.logger.Info("Skipping message", "messageType", msg.MessageType)
+		return nil, fmt.Errorf("invalid message type: %s", msg.MessageType)
+	}
+
 	l.logger.Info("Processing message", "messageId", msg.MessageId, "messageType", msg.MessageType)
 	if msg.Statistics == nil {
-		return nil, nil, fmt.Errorf("invalid message: statistics is nil")
+		return nil, fmt.Errorf("invalid message: statistics is nil")
 	}
 
 	l.logger.Info("New runner scale set statistics.", "statistics", msg.Statistics)
 
-	if msg.MessageType != "RunnerScaleSetJobMessages" {
-		l.logger.Info("Skipping message", "messageType", msg.MessageType)
-		return nil, nil, fmt.Errorf("invalid message type: %s", msg.MessageType)
-	}
-
 	var batchedMessages []json.RawMessage
 	if len(msg.Body) > 0 {
 		if err := json.Unmarshal([]byte(msg.Body), &batchedMessages); err != nil {
-			return nil, nil, fmt.Errorf("failed to unmarshal batched messages: %w", err)
+			return nil, fmt.Errorf("failed to unmarshal batched messages: %w", err)
 		}
 	}
 
-	var availableJobs []int64
-	var startedJobs []*actions.JobStarted
+	parsedMsg := &parsedMessage{
+		statistics: msg.Statistics,
+	}
+
 	for _, msg := range batchedMessages {
 		var messageType actions.JobMessageType
 		if err := json.Unmarshal(msg, &messageType); err != nil {
-			return nil, nil, fmt.Errorf("failed to decode job message type: %w", err)
+			return nil, fmt.Errorf("failed to decode job message type: %w", err)
 		}
 
 		switch messageType.MessageType {
 		case messageTypeJobAvailable:
 			var jobAvailable actions.JobAvailable
 			if err := json.Unmarshal(msg, &jobAvailable); err != nil {
-				return nil, nil, fmt.Errorf("failed to decode job available: %w", err)
+				return nil, fmt.Errorf("failed to decode job available: %w", err)
 			}
 
 			l.logger.Info("Job available message received", "jobId", jobAvailable.RunnerRequestId)
-			availableJobs = append(availableJobs, jobAvailable.RunnerRequestId)
+			parsedMsg.jobsAvailable = append(parsedMsg.jobsAvailable, &jobAvailable)
 
 		case messageTypeJobAssigned:
 			var jobAssigned actions.JobAssigned
 			if err := json.Unmarshal(msg, &jobAssigned); err != nil {
-				return nil, nil, fmt.Errorf("failed to decode job assigned: %w", err)
+				return nil, fmt.Errorf("failed to decode job assigned: %w", err)
 			}
 
 			l.logger.Info("Job assigned message received", "jobId", jobAssigned.RunnerRequestId)
@@ -320,41 +347,37 @@ func (l *Listener) parseMessage(ctx context.Context, msg *actions.RunnerScaleSet
 		case messageTypeJobStarted:
 			var jobStarted actions.JobStarted
 			if err := json.Unmarshal(msg, &jobStarted); err != nil {
-				return nil, nil, fmt.Errorf("could not decode job started message. %w", err)
+				return nil, fmt.Errorf("failed to decode job started: %w", err)
 			}
 
 			l.logger.Info("Job started message received.", "RequestId", jobStarted.RunnerRequestId, "RunnerId", jobStarted.RunnerId)
-			startedJobs = append(startedJobs, &jobStarted)
+			parsedMsg.jobsStarted = append(parsedMsg.jobsStarted, &jobStarted)
 
 		case messageTypeJobCompleted:
 			var jobCompleted actions.JobCompleted
 			if err := json.Unmarshal(msg, &jobCompleted); err != nil {
-				return nil, nil, fmt.Errorf("failed to decode job completed: %w", err)
+				return nil, fmt.Errorf("failed to decode job completed: %w", err)
 			}
 
 			l.logger.Info("Job completed message received.", "RequestId", jobCompleted.RunnerRequestId, "Result", jobCompleted.Result, "RunnerId", jobCompleted.RunnerId, "RunnerName", jobCompleted.RunnerName)
+			parsedMsg.jobsCompleted = append(parsedMsg.jobsCompleted, &jobCompleted)
 		default:
 			l.logger.Info("unknown job message type.", "messageType", messageType.MessageType)
 		}
 	}
 
-	l.logger.Info("Available jobs.", "count", len(availableJobs), "requestIds", fmt.Sprint(availableJobs))
-	if len(availableJobs) > 0 {
-		acquired, err := l.acquireAvailableJobs(ctx, availableJobs)
-		if err != nil {
-			return nil, nil, err
-		}
-
-		l.logger.Info("Jobs are acquired", "count", len(acquired), "requestIds", fmt.Sprint(acquired))
-	}
-
-	return msg.Statistics, startedJobs, nil
+	return parsedMsg, nil
 }
 
-func (l *Listener) acquireAvailableJobs(ctx context.Context, availableJobs []int64) ([]int64, error) {
-	l.logger.Info("Acquiring jobs")
+func (l *Listener) acquireAvailableJobs(ctx context.Context, jobsAvailable []*actions.JobAvailable) ([]int64, error) {
+	ids := make([]int64, 0, len(jobsAvailable))
+	for _, job := range jobsAvailable {
+		ids = append(ids, job.RunnerRequestId)
+	}
 
-	ids, err := l.client.AcquireJobs(ctx, l.scaleSetID, l.session.MessageQueueAccessToken, availableJobs)
+	l.logger.Info("Acquiring jobs", "count", len(ids), "requestIds", fmt.Sprint(ids))
+
+	ids, err := l.client.AcquireJobs(ctx, l.scaleSetID, l.session.MessageQueueAccessToken, ids)
 	if err == nil { // if NO errors
 		return ids, nil
 	}
@@ -368,7 +391,7 @@ func (l *Listener) acquireAvailableJobs(ctx context.Context, availableJobs []int
 		return nil, err
 	}
 
-	ids, err = l.client.AcquireJobs(ctx, l.scaleSetID, l.session.MessageQueueAccessToken, availableJobs)
+	ids, err = l.client.AcquireJobs(ctx, l.scaleSetID, l.session.MessageQueueAccessToken, ids)
 	if err != nil {
 		return nil, fmt.Errorf("failed to acquire jobs after session refresh: %w", err)
 	}
diff --git a/cmd/ghalistener/listener/listener_test.go b/cmd/ghalistener/listener/listener_test.go
index 86b69b83..df81ee94 100644
--- a/cmd/ghalistener/listener/listener_test.go
+++ b/cmd/ghalistener/listener/listener_test.go
@@ -2,6 +2,7 @@ package listener
 
 import (
 	"context"
+	"encoding/json"
 	"errors"
 	"net/http"
 	"testing"
@@ -9,7 +10,6 @@ import (
 
 	listenermocks "github.com/actions/actions-runner-controller/cmd/ghalistener/listener/mocks"
 	"github.com/actions/actions-runner-controller/cmd/ghalistener/metrics"
-	metricsmocks "github.com/actions/actions-runner-controller/cmd/ghalistener/metrics/mocks"
 	"github.com/actions/actions-runner-controller/github/actions"
 	"github.com/google/uuid"
 	"github.com/stretchr/testify/assert"
@@ -38,22 +38,6 @@ func TestNew(t *testing.T) {
 		assert.NotNil(t, l)
 	})
 
-	t.Run("SetStaticMetrics", func(t *testing.T) {
-		t.Parallel()
-
-		metrics := metricsmocks.NewPublisher(t)
-
-		metrics.On("PublishStatic", mock.Anything, mock.Anything).Once()
-
-		config := Config{
-			Client:     listenermocks.NewClient(t),
-			ScaleSetID: 1,
-			Metrics:    metrics,
-		}
-		l, err := New(config)
-		assert.Nil(t, err)
-		assert.NotNil(t, l)
-	})
 }
 
 func TestListener_createSession(t *testing.T) {
@@ -443,7 +427,7 @@ func TestListener_Listen(t *testing.T) {
 		var called bool
 		handler := listenermocks.NewHandler(t)
 		handler.On("HandleDesiredRunnerCount", mock.Anything, mock.Anything).
-			Return(nil).
+			Return(0, nil).
 			Run(
 				func(mock.Arguments) {
 					called = true
@@ -456,6 +440,63 @@ func TestListener_Listen(t *testing.T) {
 		assert.True(t, errors.Is(err, context.Canceled))
 		assert.True(t, called)
 	})
+
+	t.Run("CancelContextAfterGetMessage", func(t *testing.T) {
+		t.Parallel()
+
+		ctx, cancel := context.WithCancel(context.Background())
+
+		config := Config{
+			ScaleSetID: 1,
+			Metrics:    metrics.Discard,
+		}
+
+		client := listenermocks.NewClient(t)
+		uuid := uuid.New()
+		session := &actions.RunnerScaleSetSession{
+			SessionId:               &uuid,
+			OwnerName:               "example",
+			RunnerScaleSet:          &actions.RunnerScaleSet{},
+			MessageQueueUrl:         "https://example.com",
+			MessageQueueAccessToken: "1234567890",
+			Statistics:              &actions.RunnerScaleSetStatistic{},
+		}
+		client.On("CreateMessageSession", ctx, mock.Anything, mock.Anything).Return(session, nil).Once()
+
+		msg := &actions.RunnerScaleSetMessage{
+			MessageId:   1,
+			MessageType: "RunnerScaleSetJobMessages",
+			Statistics:  &actions.RunnerScaleSetStatistic{},
+		}
+		client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything).
+			Return(msg, nil).
+			Run(
+				func(mock.Arguments) {
+					cancel()
+				},
+			).
+			Once()
+
+		// Ensure delete message is called with background context
+		client.On("DeleteMessage", context.Background(), mock.Anything, mock.Anything, mock.Anything).Return(nil).Once()
+
+		config.Client = client
+
+		handler := listenermocks.NewHandler(t)
+		handler.On("HandleDesiredRunnerCount", mock.Anything, mock.Anything).
+			Return(0, nil).
+			Once()
+
+		handler.On("HandleDesiredRunnerCount", mock.Anything, mock.Anything).
+			Return(0, nil).
+			Once()
+
+		l, err := New(config)
+		require.Nil(t, err)
+
+		err = l.Listen(ctx, handler)
+		assert.ErrorIs(t, err, context.Canceled)
+	})
 }
 
 func TestListener_acquireAvailableJobs(t *testing.T) {
@@ -489,7 +530,24 @@ func TestListener_acquireAvailableJobs(t *testing.T) {
 			Statistics:              &actions.RunnerScaleSetStatistic{},
 		}
 
-		_, err = l.acquireAvailableJobs(ctx, []int64{1, 2, 3})
+		availableJobs := []*actions.JobAvailable{
+			{
+				JobMessageBase: actions.JobMessageBase{
+					RunnerRequestId: 1,
+				},
+			},
+			{
+				JobMessageBase: actions.JobMessageBase{
+					RunnerRequestId: 2,
+				},
+			},
+			{
+				JobMessageBase: actions.JobMessageBase{
+					RunnerRequestId: 3,
+				},
+			},
+		}
+		_, err = l.acquireAvailableJobs(ctx, availableJobs)
 		assert.Error(t, err)
 	})
 
@@ -523,9 +581,26 @@ func TestListener_acquireAvailableJobs(t *testing.T) {
 			Statistics:              &actions.RunnerScaleSetStatistic{},
 		}
 
-		acquiredJobIDs, err := l.acquireAvailableJobs(ctx, []int64{1, 2, 3})
+		availableJobs := []*actions.JobAvailable{
+			{
+				JobMessageBase: actions.JobMessageBase{
+					RunnerRequestId: 1,
+				},
+			},
+			{
+				JobMessageBase: actions.JobMessageBase{
+					RunnerRequestId: 2,
+				},
+			},
+			{
+				JobMessageBase: actions.JobMessageBase{
+					RunnerRequestId: 3,
+				},
+			},
+		}
+		acquiredJobIDs, err := l.acquireAvailableJobs(ctx, availableJobs)
 		assert.NoError(t, err)
-		assert.Equal(t, jobIDs, acquiredJobIDs)
+		assert.Equal(t, []int64{1, 2, 3}, acquiredJobIDs)
 	})
 
 	t.Run("RefreshAndSucceeds", func(t *testing.T) {
@@ -555,6 +630,23 @@ func TestListener_acquireAvailableJobs(t *testing.T) {
 
 		// Second call to AcquireJobs will succeed
 		want := []int64{1, 2, 3}
+		availableJobs := []*actions.JobAvailable{
+			{
+				JobMessageBase: actions.JobMessageBase{
+					RunnerRequestId: 1,
+				},
+			},
+			{
+				JobMessageBase: actions.JobMessageBase{
+					RunnerRequestId: 2,
+				},
+			},
+			{
+				JobMessageBase: actions.JobMessageBase{
+					RunnerRequestId: 3,
+				},
+			},
+		}
 		client.On("AcquireJobs", ctx, mock.Anything, mock.Anything, mock.Anything).Return(want, nil).Once()
 
 		config.Client = client
@@ -567,7 +659,7 @@ func TestListener_acquireAvailableJobs(t *testing.T) {
 			RunnerScaleSet:          &actions.RunnerScaleSet{},
 		}
 
-		got, err := l.acquireAvailableJobs(ctx, want)
+		got, err := l.acquireAvailableJobs(ctx, availableJobs)
 		assert.Nil(t, err)
 		assert.Equal(t, want, got)
 	})
@@ -606,8 +698,165 @@ func TestListener_acquireAvailableJobs(t *testing.T) {
 			RunnerScaleSet:          &actions.RunnerScaleSet{},
 		}
 
-		got, err := l.acquireAvailableJobs(ctx, []int64{1, 2, 3})
+		availableJobs := []*actions.JobAvailable{
+			{
+				JobMessageBase: actions.JobMessageBase{
+					RunnerRequestId: 1,
+				},
+			},
+			{
+				JobMessageBase: actions.JobMessageBase{
+					RunnerRequestId: 2,
+				},
+			},
+			{
+				JobMessageBase: actions.JobMessageBase{
+					RunnerRequestId: 3,
+				},
+			},
+		}
+
+		got, err := l.acquireAvailableJobs(ctx, availableJobs)
 		assert.NotNil(t, err)
 		assert.Nil(t, got)
 	})
 }
+
+func TestListener_parseMessage(t *testing.T) {
+	t.Run("FailOnEmptyStatistics", func(t *testing.T) {
+		msg := &actions.RunnerScaleSetMessage{
+			MessageId:   1,
+			MessageType: "RunnerScaleSetJobMessages",
+			Statistics:  nil,
+		}
+
+		l := &Listener{}
+		parsedMsg, err := l.parseMessage(context.Background(), msg)
+		assert.Error(t, err)
+		assert.Nil(t, parsedMsg)
+	})
+
+	t.Run("FailOnIncorrectMessageType", func(t *testing.T) {
+		msg := &actions.RunnerScaleSetMessage{
+			MessageId:   1,
+			MessageType: "RunnerMessages", // arbitrary message type
+			Statistics:  &actions.RunnerScaleSetStatistic{},
+		}
+
+		l := &Listener{}
+		parsedMsg, err := l.parseMessage(context.Background(), msg)
+		assert.Error(t, err)
+		assert.Nil(t, parsedMsg)
+	})
+
+	t.Run("ParseAll", func(t *testing.T) {
+		msg := &actions.RunnerScaleSetMessage{
+			MessageId:   1,
+			MessageType: "RunnerScaleSetJobMessages",
+			Body:        "",
+			Statistics: &actions.RunnerScaleSetStatistic{
+				TotalAvailableJobs:     1,
+				TotalAcquiredJobs:      2,
+				TotalAssignedJobs:      3,
+				TotalRunningJobs:       4,
+				TotalRegisteredRunners: 5,
+				TotalBusyRunners:       6,
+				TotalIdleRunners:       7,
+			},
+		}
+
+		var batchedMessages []any
+		jobsAvailable := []*actions.JobAvailable{
+			{
+				AcquireJobUrl: "https://github.com/example",
+				JobMessageBase: actions.JobMessageBase{
+					JobMessageType: actions.JobMessageType{
+						MessageType: messageTypeJobAvailable,
+					},
+					RunnerRequestId: 1,
+				},
+			},
+			{
+				AcquireJobUrl: "https://github.com/example",
+				JobMessageBase: actions.JobMessageBase{
+					JobMessageType: actions.JobMessageType{
+						MessageType: messageTypeJobAvailable,
+					},
+					RunnerRequestId: 2,
+				},
+			},
+		}
+		for _, msg := range jobsAvailable {
+			batchedMessages = append(batchedMessages, msg)
+		}
+
+		jobsAssigned := []*actions.JobAssigned{
+			{
+				JobMessageBase: actions.JobMessageBase{
+					JobMessageType: actions.JobMessageType{
+						MessageType: messageTypeJobAssigned,
+					},
+					RunnerRequestId: 3,
+				},
+			},
+			{
+				JobMessageBase: actions.JobMessageBase{
+					JobMessageType: actions.JobMessageType{
+						MessageType: messageTypeJobAssigned,
+					},
+					RunnerRequestId: 4,
+				},
+			},
+		}
+		for _, msg := range jobsAssigned {
+			batchedMessages = append(batchedMessages, msg)
+		}
+
+		jobsStarted := []*actions.JobStarted{
+			{
+				JobMessageBase: actions.JobMessageBase{
+					JobMessageType: actions.JobMessageType{
+						MessageType: messageTypeJobStarted,
+					},
+					RunnerRequestId: 5,
+				},
+				RunnerId:   2,
+				RunnerName: "runner2",
+			},
+		}
+		for _, msg := range jobsStarted {
+			batchedMessages = append(batchedMessages, msg)
+		}
+
+		jobsCompleted := []*actions.JobCompleted{
+			{
+				JobMessageBase: actions.JobMessageBase{
+					JobMessageType: actions.JobMessageType{
+						MessageType: messageTypeJobCompleted,
+					},
+					RunnerRequestId: 6,
+				},
+				Result:     "success",
+				RunnerId:   1,
+				RunnerName: "runner1",
+			},
+		}
+		for _, msg := range jobsCompleted {
+			batchedMessages = append(batchedMessages, msg)
+		}
+
+		b, err := json.Marshal(batchedMessages)
+		require.NoError(t, err)
+
+		msg.Body = string(b)
+
+		l := &Listener{}
+		parsedMsg, err := l.parseMessage(context.Background(), msg)
+		require.NoError(t, err)
+
+		assert.Equal(t, msg.Statistics, parsedMsg.statistics)
+		assert.Equal(t, jobsAvailable, parsedMsg.jobsAvailable)
+		assert.Equal(t, jobsStarted, parsedMsg.jobsStarted)
+		assert.Equal(t, jobsCompleted, parsedMsg.jobsCompleted)
+	})
+}
diff --git a/cmd/ghalistener/listener/metrics_test.go b/cmd/ghalistener/listener/metrics_test.go
new file mode 100644
index 00000000..96ca6ac2
--- /dev/null
+++ b/cmd/ghalistener/listener/metrics_test.go
@@ -0,0 +1,204 @@
+package listener
+
+import (
+	"context"
+	"encoding/json"
+	"testing"
+
+	listenermocks "github.com/actions/actions-runner-controller/cmd/ghalistener/listener/mocks"
+	metricsmocks "github.com/actions/actions-runner-controller/cmd/ghalistener/metrics/mocks"
+	"github.com/actions/actions-runner-controller/github/actions"
+	"github.com/google/uuid"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/mock"
+	"github.com/stretchr/testify/require"
+)
+
+func TestInitialMetrics(t *testing.T) {
+	t.Parallel()
+
+	t.Run("SetStaticMetrics", func(t *testing.T) {
+		t.Parallel()
+
+		metrics := metricsmocks.NewPublisher(t)
+
+		minRunners := 5
+		maxRunners := 10
+		metrics.On("PublishStatic", minRunners, maxRunners).Once()
+
+		config := Config{
+			Client:     listenermocks.NewClient(t),
+			ScaleSetID: 1,
+			Metrics:    metrics,
+			MinRunners: minRunners,
+			MaxRunners: maxRunners,
+		}
+		l, err := New(config)
+
+		assert.Nil(t, err)
+		assert.NotNil(t, l)
+	})
+
+	t.Run("InitialMessageStatistics", func(t *testing.T) {
+		t.Parallel()
+
+		ctx, cancel := context.WithCancel(context.Background())
+
+		sessionStatistics := &actions.RunnerScaleSetStatistic{
+			TotalAvailableJobs:     1,
+			TotalAcquiredJobs:      2,
+			TotalAssignedJobs:      3,
+			TotalRunningJobs:       4,
+			TotalRegisteredRunners: 5,
+			TotalBusyRunners:       6,
+			TotalIdleRunners:       7,
+		}
+
+		uuid := uuid.New()
+		session := &actions.RunnerScaleSetSession{
+			SessionId:               &uuid,
+			OwnerName:               "example",
+			RunnerScaleSet:          &actions.RunnerScaleSet{},
+			MessageQueueUrl:         "https://example.com",
+			MessageQueueAccessToken: "1234567890",
+			Statistics:              sessionStatistics,
+		}
+
+		metrics := metricsmocks.NewPublisher(t)
+		metrics.On("PublishStatic", mock.Anything, mock.Anything).Once()
+		metrics.On("PublishStatistics", sessionStatistics).Once()
+		metrics.On("PublishDesiredRunners", sessionStatistics.TotalAssignedJobs).
+			Run(
+				func(mock.Arguments) {
+					cancel()
+				},
+			).Once()
+
+		config := Config{
+			Client:     listenermocks.NewClient(t),
+			ScaleSetID: 1,
+			Metrics:    metrics,
+		}
+
+		client := listenermocks.NewClient(t)
+		client.On("CreateMessageSession", mock.Anything, mock.Anything, mock.Anything).Return(session, nil).Once()
+		config.Client = client
+
+		handler := listenermocks.NewHandler(t)
+		handler.On("HandleDesiredRunnerCount", mock.Anything, sessionStatistics.TotalAssignedJobs).
+			Return(sessionStatistics.TotalAssignedJobs, nil).
+			Once()
+
+		l, err := New(config)
+		assert.Nil(t, err)
+		assert.NotNil(t, l)
+
+		assert.ErrorIs(t, l.Listen(ctx, handler), context.Canceled)
+	})
+}
+
+func TestHandleMessageMetrics(t *testing.T) {
+	t.Parallel()
+
+	msg := &actions.RunnerScaleSetMessage{
+		MessageId:   1,
+		MessageType: "RunnerScaleSetJobMessages",
+		Body:        "",
+		Statistics: &actions.RunnerScaleSetStatistic{
+			TotalAvailableJobs:     1,
+			TotalAcquiredJobs:      2,
+			TotalAssignedJobs:      3,
+			TotalRunningJobs:       4,
+			TotalRegisteredRunners: 5,
+			TotalBusyRunners:       6,
+			TotalIdleRunners:       7,
+		},
+	}
+
+	var batchedMessages []any
+	jobsStarted := []*actions.JobStarted{
+		{
+			JobMessageBase: actions.JobMessageBase{
+				JobMessageType: actions.JobMessageType{
+					MessageType: messageTypeJobStarted,
+				},
+				RunnerRequestId: 8,
+			},
+			RunnerId:   3,
+			RunnerName: "runner3",
+		},
+	}
+	for _, msg := range jobsStarted {
+		batchedMessages = append(batchedMessages, msg)
+	}
+
+	jobsCompleted := []*actions.JobCompleted{
+		{
+			JobMessageBase: actions.JobMessageBase{
+				JobMessageType: actions.JobMessageType{
+					MessageType: messageTypeJobCompleted,
+				},
+				RunnerRequestId: 6,
+			},
+			Result:     "success",
+			RunnerId:   1,
+			RunnerName: "runner1",
+		},
+		{
+			JobMessageBase: actions.JobMessageBase{
+				JobMessageType: actions.JobMessageType{
+					MessageType: messageTypeJobCompleted,
+				},
+				RunnerRequestId: 7,
+			},
+			Result:     "success",
+			RunnerId:   2,
+			RunnerName: "runner2",
+		},
+	}
+	for _, msg := range jobsCompleted {
+		batchedMessages = append(batchedMessages, msg)
+	}
+
+	b, err := json.Marshal(batchedMessages)
+	require.NoError(t, err)
+
+	msg.Body = string(b)
+
+	desiredResult := 4
+
+	metrics := metricsmocks.NewPublisher(t)
+	metrics.On("PublishStatic", 0, 0).Once()
+	metrics.On("PublishStatistics", msg.Statistics).Once()
+	metrics.On("PublishJobCompleted", jobsCompleted[0]).Once()
+	metrics.On("PublishJobCompleted", jobsCompleted[1]).Once()
+	metrics.On("PublishJobStarted", jobsStarted[0]).Once()
+	metrics.On("PublishDesiredRunners", desiredResult).Once()
+
+	handler := listenermocks.NewHandler(t)
+	handler.On("HandleJobStarted", mock.Anything, jobsStarted[0]).Return(nil).Once()
+	handler.On("HandleDesiredRunnerCount", mock.Anything, mock.Anything).Return(desiredResult, nil).Once()
+
+	client := listenermocks.NewClient(t)
+	client.On("DeleteMessage", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Once()
+
+	config := Config{
+		Client:     listenermocks.NewClient(t),
+		ScaleSetID: 1,
+		Metrics:    metrics,
+	}
+
+	l, err := New(config)
+	require.NoError(t, err)
+	l.client = client
+	l.session = &actions.RunnerScaleSetSession{
+		OwnerName:               "",
+		RunnerScaleSet:          &actions.RunnerScaleSet{},
+		MessageQueueUrl:         "",
+		MessageQueueAccessToken: "",
+		Statistics:              &actions.RunnerScaleSetStatistic{},
+	}
+
+	err = l.handleMessage(context.Background(), handler, msg)
+	require.NoError(t, err)
+}
diff --git a/cmd/ghalistener/listener/mocks/handler.go b/cmd/ghalistener/listener/mocks/handler.go
index c78fe250..edc1b30b 100644
--- a/cmd/ghalistener/listener/mocks/handler.go
+++ b/cmd/ghalistener/listener/mocks/handler.go
@@ -15,18 +15,28 @@ type Handler struct {
 	mock.Mock
 }
 
-// HandleDesiredRunnerCount provides a mock function with given fields: ctx, desiredRunnerCount
-func (_m *Handler) HandleDesiredRunnerCount(ctx context.Context, desiredRunnerCount int) error {
-	ret := _m.Called(ctx, desiredRunnerCount)
+// HandleDesiredRunnerCount provides a mock function with given fields: ctx, count
+func (_m *Handler) HandleDesiredRunnerCount(ctx context.Context, count int) (int, error) {
+	ret := _m.Called(ctx, count)
 
-	var r0 error
-	if rf, ok := ret.Get(0).(func(context.Context, int) error); ok {
-		r0 = rf(ctx, desiredRunnerCount)
+	var r0 int
+	var r1 error
+	if rf, ok := ret.Get(0).(func(context.Context, int) (int, error)); ok {
+		return rf(ctx, count)
+	}
+	if rf, ok := ret.Get(0).(func(context.Context, int) int); ok {
+		r0 = rf(ctx, count)
 	} else {
-		r0 = ret.Error(0)
+		r0 = ret.Get(0).(int)
 	}
 
-	return r0
+	if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+		r1 = rf(ctx, count)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
 }
 
 // HandleJobStarted provides a mock function with given fields: ctx, jobInfo
diff --git a/cmd/ghalistener/worker/worker.go b/cmd/ghalistener/worker/worker.go
index f9d7b7db..169f0251 100644
--- a/cmd/ghalistener/worker/worker.go
+++ b/cmd/ghalistener/worker/worker.go
@@ -156,7 +156,7 @@ func (w *Worker) HandleJobStarted(ctx context.Context, jobInfo *actions.JobStart
 // The function then scales the ephemeral runner set by applying the merge patch.
 // Finally, it logs the scaled ephemeral runner set details and returns nil if successful.
 // If any error occurs during the process, it returns an error with a descriptive message.
-func (w *Worker) HandleDesiredRunnerCount(ctx context.Context, count int) error {
+func (w *Worker) HandleDesiredRunnerCount(ctx context.Context, count int) (int, error) {
 	// Max runners should always be set by the resource builder either to the configured value,
 	// or the maximum int32 (resourcebuilder.newAutoScalingListener()).
 	targetRunnerCount := min(w.config.MinRunners+count, w.config.MaxRunners)
@@ -171,7 +171,7 @@ func (w *Worker) HandleDesiredRunnerCount(ctx context.Context, count int) error
 
 	if targetRunnerCount == w.lastPatch {
 		w.logger.Info("Skipping patching of EphemeralRunnerSet as the desired count has not changed", logValues...)
-		return nil
+		return targetRunnerCount, nil
 	}
 
 	original, err := json.Marshal(
@@ -182,7 +182,7 @@ func (w *Worker) HandleDesiredRunnerCount(ctx context.Context, count int) error
 		},
 	)
 	if err != nil {
-		return fmt.Errorf("failed to marshal empty ephemeral runner set: %w", err)
+		return 0, fmt.Errorf("failed to marshal empty ephemeral runner set: %w", err)
 	}
 
 	patch, err := json.Marshal(
@@ -194,12 +194,12 @@ func (w *Worker) HandleDesiredRunnerCount(ctx context.Context, count int) error
 	)
 	if err != nil {
 		w.logger.Error(err, "could not marshal patch ephemeral runner set")
-		return err
+		return 0, err
 	}
 
 	mergePatch, err := jsonpatch.CreateMergePatch(original, patch)
 	if err != nil {
-		return fmt.Errorf("failed to create merge patch json for ephemeral runner set: %w", err)
+		return 0, fmt.Errorf("failed to create merge patch json for ephemeral runner set: %w", err)
 	}
 
 	w.logger.Info("Created merge patch json for EphemeralRunnerSet update", "json", string(mergePatch))
@@ -217,7 +217,7 @@ func (w *Worker) HandleDesiredRunnerCount(ctx context.Context, count int) error
 		Do(ctx).
 		Into(patchedEphemeralRunnerSet)
 	if err != nil {
-		return fmt.Errorf("could not patch ephemeral runner set , patch JSON: %s, error: %w", string(mergePatch), err)
+		return 0, fmt.Errorf("could not patch ephemeral runner set, patch JSON: %s, error: %w", string(mergePatch), err)
 	}
 
 	w.logger.Info("Ephemeral runner set scaled.",
@@ -225,5 +225,5 @@ func (w *Worker) HandleDesiredRunnerCount(ctx context.Context, count int) error
 		"name", w.config.EphemeralRunnerSetName,
 		"replicas", patchedEphemeralRunnerSet.Spec.Replicas,
 	)
-	return nil
+	return targetRunnerCount, nil
 }
diff --git a/controllers/actions.github.com/resourcebuilder.go b/controllers/actions.github.com/resourcebuilder.go
index 18f58840..0ee48326 100644
--- a/controllers/actions.github.com/resourcebuilder.go
+++ b/controllers/actions.github.com/resourcebuilder.go
@@ -226,6 +226,7 @@ func (b *resourceBuilder) newScaleSetListenerPod(autoscalingListener *v1alpha1.A
 		ports = append(ports, port)
 	}
 
+	terminationGracePeriodSeconds := int64(60)
 	podSpec := corev1.PodSpec{
 		ServiceAccountName: serviceAccount.Name,
 		Containers: []corev1.Container{
@@ -256,8 +257,9 @@ func (b *resourceBuilder) newScaleSetListenerPod(autoscalingListener *v1alpha1.A
 			},
 		},
-		ImagePullSecrets: autoscalingListener.Spec.ImagePullSecrets,
-		RestartPolicy:    corev1.RestartPolicyNever,
+		ImagePullSecrets:              autoscalingListener.Spec.ImagePullSecrets,
+		RestartPolicy:                 corev1.RestartPolicyNever,
+		TerminationGracePeriodSeconds: &terminationGracePeriodSeconds,
 	}
 
 	labels := make(map[string]string, len(autoscalingListener.Labels))