From 3bda9bb2402b361e3bab78fdcd2b67da1c4e3638 Mon Sep 17 00:00:00 2001 From: Nikola Jokic Date: Fri, 17 May 2024 15:16:38 +0200 Subject: [PATCH] Refresh session if token expires during delete message (#3529) --- cmd/ghalistener/listener/listener.go | 19 ++++- cmd/ghalistener/listener/listener_test.go | 87 +++++++++++++++++++++++ 2 files changed, 104 insertions(+), 2 deletions(-) diff --git a/cmd/ghalistener/listener/listener.go b/cmd/ghalistener/listener/listener.go index 79009726..a9cf0838 100644 --- a/cmd/ghalistener/listener/listener.go +++ b/cmd/ghalistener/listener/listener.go @@ -295,8 +295,23 @@ func (l *Listener) getMessage(ctx context.Context) (*actions.RunnerScaleSetMessa func (l *Listener) deleteLastMessage(ctx context.Context) error { l.logger.Info("Deleting last message", "lastMessageID", l.lastMessageID) - if err := l.client.DeleteMessage(ctx, l.session.MessageQueueUrl, l.session.MessageQueueAccessToken, l.lastMessageID); err != nil { - return fmt.Errorf("failed to delete message: %w", err) + err := l.client.DeleteMessage(ctx, l.session.MessageQueueUrl, l.session.MessageQueueAccessToken, l.lastMessageID) + if err == nil { // if NO error + return nil + } + + expiredError := &actions.MessageQueueTokenExpiredError{} + if !errors.As(err, &expiredError) { + return fmt.Errorf("failed to delete last message: %w", err) + } + + if err := l.refreshSession(ctx); err != nil { + return err + } + + err = l.client.DeleteMessage(ctx, l.session.MessageQueueUrl, l.session.MessageQueueAccessToken, l.lastMessageID) + if err != nil { + return fmt.Errorf("failed to delete last message after message session refresh: %w", err) } return nil diff --git a/cmd/ghalistener/listener/listener_test.go b/cmd/ghalistener/listener/listener_test.go index d2af2b03..38c1f40f 100644 --- a/cmd/ghalistener/listener/listener_test.go +++ b/cmd/ghalistener/listener/listener_test.go @@ -377,6 +377,93 @@ func TestListener_deleteLastMessage(t *testing.T) { err = l.deleteLastMessage(ctx) assert.NotNil(t, err) }) + + t.Run("RefreshAndSucceeds", func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + config := Config{ + ScaleSetID: 1, + Metrics: metrics.Discard, + } + + client := listenermocks.NewClient(t) + + newUUID := uuid.New() + session := &actions.RunnerScaleSetSession{ + SessionId: &newUUID, + OwnerName: "example", + RunnerScaleSet: &actions.RunnerScaleSet{}, + MessageQueueUrl: "https://example.com", + MessageQueueAccessToken: "1234567890", + Statistics: nil, + } + client.On("RefreshMessageSession", ctx, mock.Anything, mock.Anything).Return(session, nil).Once() + + client.On("DeleteMessage", ctx, mock.Anything, mock.Anything, mock.Anything).Return(&actions.MessageQueueTokenExpiredError{}).Once() + + client.On("DeleteMessage", ctx, mock.Anything, mock.Anything, mock.MatchedBy(func(lastMessageID any) bool { + return lastMessageID.(int64) == int64(5) + })).Return(nil).Once() + config.Client = client + + l, err := New(config) + require.Nil(t, err) + + oldUUID := uuid.New() + l.session = &actions.RunnerScaleSetSession{ + SessionId: &oldUUID, + RunnerScaleSet: &actions.RunnerScaleSet{}, + } + l.lastMessageID = 5 + + config.Client = client + + err = l.deleteLastMessage(ctx) + assert.NoError(t, err) + }) + + t.Run("RefreshAndFails", func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + config := Config{ + ScaleSetID: 1, + Metrics: metrics.Discard, + } + + client := listenermocks.NewClient(t) + + newUUID := uuid.New() + session := &actions.RunnerScaleSetSession{ + SessionId: &newUUID, + OwnerName: "example", + RunnerScaleSet: &actions.RunnerScaleSet{}, + MessageQueueUrl: "https://example.com", + MessageQueueAccessToken: "1234567890", + Statistics: nil, + } + client.On("RefreshMessageSession", ctx, mock.Anything, mock.Anything).Return(session, nil).Once() + + client.On("DeleteMessage", ctx, mock.Anything, mock.Anything, mock.Anything).Return(&actions.MessageQueueTokenExpiredError{}).Twice() + + config.Client = client + + l, err := New(config) + require.Nil(t, err) + + oldUUID := uuid.New() + l.session = &actions.RunnerScaleSetSession{ + SessionId: &oldUUID, + RunnerScaleSet: &actions.RunnerScaleSet{}, + } + l.lastMessageID = 5 + + config.Client = client + + err = l.deleteLastMessage(ctx) + assert.Error(t, err) + }) } func TestListener_Listen(t *testing.T) {