diff --git a/github/actions/client.go b/github/actions/client.go index bea3e9b2..73632e56 100644 --- a/github/actions/client.go +++ b/github/actions/client.go @@ -634,7 +634,18 @@ func (c *Client) AcquireJobs(ctx context.Context, runnerScaleSetId int, messageQ } if resp.StatusCode != http.StatusOK { - return nil, ParseActionsErrorFromResponse(resp) + if resp.StatusCode != http.StatusUnauthorized { + return nil, ParseActionsErrorFromResponse(resp) + } + + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + body = trimByteOrderMark(body) + if err != nil { + return nil, err + } + + return nil, &MessageQueueTokenExpiredError{msg: string(body)} } var acquiredJobs *Int64List diff --git a/github/actions/client_job_acquisition_test.go b/github/actions/client_job_acquisition_test.go index e11b7ba2..38c81e05 100644 --- a/github/actions/client_job_acquisition_test.go +++ b/github/actions/client_job_acquisition_test.go @@ -2,6 +2,7 @@ package actions_test import ( "context" + "errors" "net/http" "strings" "testing" @@ -84,6 +85,39 @@ func TestAcquireJobs(t *testing.T) { assert.NotNil(t, err) assert.Equalf(t, actualRetry, expectedRetry, "A retry was expected after the first request but got: %v", actualRetry) }) + + t.Run("Should return MessageQueueTokenExpiredError when http error is not Unauthorized", func(t *testing.T) { + want := []int64{1} + + session := &actions.RunnerScaleSetSession{ + RunnerScaleSet: &actions.RunnerScaleSet{Id: 1}, + MessageQueueAccessToken: "abc", + } + requestIDs := want + + server := newActionsServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.HasSuffix(r.URL.Path, "/acquirablejobs") { + w.Write([]byte(`{"count": 1}`)) + return + } + if r.Method == http.MethodPost { + http.Error(w, "Session expired", http.StatusUnauthorized) + return + } + })) + + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) + require.NoError(t, err) + + _, err = client.GetAcquirableJobs(ctx, 1) + require.NoError(t, err) + + got, err := client.AcquireJobs(ctx, session.RunnerScaleSet.Id, session.MessageQueueAccessToken, requestIDs) + require.Error(t, err) + assert.Nil(t, got) + var expectedErr *actions.MessageQueueTokenExpiredError + assert.True(t, errors.As(err, &expectedErr)) + }) } func TestGetAcquirableJobs(t *testing.T) {