diff --git a/cmd/githubrunnerscalesetlistener/main.go b/cmd/githubrunnerscalesetlistener/main.go index 06684597..ed9f145f 100644 --- a/cmd/githubrunnerscalesetlistener/main.go +++ b/cmd/githubrunnerscalesetlistener/main.go @@ -85,7 +85,6 @@ func run(rc RunnerScaleSetListenerConfig, logger logr.Logger) error { } actionsServiceClient, err := actions.NewClient( - ctx, rc.ConfigureUrl, creds, actions.WithUserAgent(fmt.Sprintf("actions-runner-controller/%s", build.Version)), diff --git a/github/actions/actions_server_test.go b/github/actions/actions_server_test.go index 8b66e45f..2638435a 100644 --- a/github/actions/actions_server_test.go +++ b/github/actions/actions_server_test.go @@ -17,9 +17,20 @@ import ( // /actions/runner-registration endpoints will be handled by the provided // handler. The returned server is started and will be automatically closed // when the test ends. -func newActionsServer(t *testing.T, handler http.Handler) *actionsServer { - var u string - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +func newActionsServer(t *testing.T, handler http.Handler, options ...actionsServerOption) *actionsServer { + s := httptest.NewServer(nil) + server := &actionsServer{ + Server: s, + } + t.Cleanup(func() { + server.Close() + }) + + for _, option := range options { + option(server) + } + + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // handle getRunnerRegistrationToken if strings.HasSuffix(r.URL.Path, "/runners/registration-token") { w.WriteHeader(http.StatusCreated) @@ -29,41 +40,55 @@ func newActionsServer(t *testing.T, handler http.Handler) *actionsServer { // handle getActionsServiceAdminConnection if strings.HasSuffix(r.URL.Path, "/actions/runner-registration") { - claims := &jwt.RegisteredClaims{ - IssuedAt: jwt.NewNumericDate(time.Now().Add(-1 * time.Minute)), - ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Minute)), - Issuer: "123", + if server.token == "" { + server.token = defaultActionsToken(t) } - token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) - privateKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(samplePrivateKey)) - require.NoError(t, err) - tokenString, err := token.SignedString(privateKey) - require.NoError(t, err) - w.Write([]byte(`{"url":"` + u + `","token":"` + tokenString + `"}`)) + w.Write([]byte(`{"url":"` + s.URL + `/tenant/123/","token":"` + server.token + `"}`)) return } handler.ServeHTTP(w, r) - })) - - u = server.URL - - t.Cleanup(func() { - server.Close() }) - return &actionsServer{server} + server.Config.Handler = h + + return server +} + +type actionsServerOption func(*actionsServer) + +func withActionsToken(token string) actionsServerOption { + return func(s *actionsServer) { + s.token = token + } } type actionsServer struct { *httptest.Server + + token string } func (s *actionsServer) configURLForOrg(org string) string { return s.URL + "/" + org } +func defaultActionsToken(t *testing.T) string { + claims := &jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute)), + ExpiresAt: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), + Issuer: "123", + } + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + privateKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(samplePrivateKey)) + require.NoError(t, err) + tokenString, err := token.SignedString(privateKey) + require.NoError(t, err) + return tokenString +} + const samplePrivateKey = `-----BEGIN RSA PRIVATE KEY----- MIICWgIBAAKBgHXfRT9cv9UY9fAAD4+1RshpfSSZe277urfEmPfX3/Og9zJYRk// CZrJVD1CaBZDiIyQsNEzjta7r4UsqWdFOggiNN2E7ZTFQjMSaFkVgrzHqWuiaCBf diff --git a/github/actions/byte_order_mark_test.go b/github/actions/byte_order_mark_test.go new file mode 100644 index 00000000..107dd92a --- /dev/null +++ b/github/actions/byte_order_mark_test.go @@ -0,0 +1,61 @@ +package actions_test + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/actions/actions-runner-controller/github/actions" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestClient_Do(t *testing.T) { + t.Run("trims byte order mark from response if present", func(t *testing.T) { + t.Run("when there is no body", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + })) + defer server.Close() + + client, err := actions.NewClient("https://localhost/org/repo", &actions.ActionsAuth{Token: "token"}) + require.NoError(t, err) + + req, err := http.NewRequest("GET", server.URL, nil) + require.NoError(t, err) + + resp, err := client.Do(req) + require.NoError(t, err) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Empty(t, string(body)) + }) + + responses := []string{ + "\xef\xbb\xbf{\"foo\":\"bar\"}", + "{\"foo\":\"bar\"}", + } + + for _, response := range responses { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(response)) + })) + defer server.Close() + + client, err := actions.NewClient("https://localhost/org/repo", &actions.ActionsAuth{Token: "token"}) + require.NoError(t, err) + + req, err := http.NewRequest("GET", server.URL, nil) + require.NoError(t, err) + + resp, err := client.Do(req) + require.NoError(t, err) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, "{\"foo\":\"bar\"}", string(body)) + } + }) +} diff --git a/github/actions/client.go b/github/actions/client.go index fdadc70b..70b02b14 100644 --- a/github/actions/client.go +++ b/github/actions/client.go @@ -12,9 +12,7 @@ import ( "log" "net/http" "net/url" - "path" "strconv" - "strings" "sync" "time" @@ -62,17 +60,17 @@ type Client struct { mu sync.Mutex // TODO: Convert to unexported fields once refactor of Listener is complete - ActionsServiceAdminToken *string - ActionsServiceAdminTokenExpiresAt *time.Time - ActionsServiceURL *string + ActionsServiceAdminToken string + ActionsServiceAdminTokenExpiresAt time.Time + ActionsServiceURL string retryMax int retryWaitMax time.Duration - creds *ActionsAuth - githubConfigURL string - logger logr.Logger - userAgent string + creds *ActionsAuth + config *GitHubConfig + logger logr.Logger + userAgent string rootCAs *x509.CertPool tlsInsecureSkipVerify bool @@ -116,11 +114,16 @@ func WithoutTLSVerify() ClientOption { } } -func NewClient(ctx context.Context, githubConfigURL string, creds *ActionsAuth, options ...ClientOption) (ActionsService, error) { +func NewClient(githubConfigURL string, creds *ActionsAuth, options ...ClientOption) (*Client, error) { + config, err := ParseGitHubConfigFromURL(githubConfigURL) + if err != nil { + return nil, fmt.Errorf("failed to parse githubConfigURL: %w", err) + } + ac := &Client{ - creds: creds, - githubConfigURL: githubConfigURL, - logger: logr.Discard(), + creds: creds, + config: config, + logger: logr.Discard(), // retryablehttp defaults retryMax: 4, @@ -132,9 +135,6 @@ func NewClient(ctx context.Context, githubConfigURL string, creds *ActionsAuth, } retryClient := retryablehttp.NewClient() - - // TODO: this silences retryclient default logger, do we want to provide one - // instead? by default retryablehttp logs all requests to stderr retryClient.Logger = log.New(io.Discard, "", log.LstdFlags) retryClient.RetryMax = ac.retryMax @@ -161,48 +161,94 @@ func NewClient(ctx context.Context, githubConfigURL string, creds *ActionsAuth, retryClient.HTTPClient.Transport = transport ac.Client = retryClient.StandardClient() - rt, err := ac.getRunnerRegistrationToken(ctx, githubConfigURL, *creds) - if err != nil { - return nil, fmt.Errorf("failed to get runner registration token: %w", err) - } - - adminConnInfo, err := ac.getActionsServiceAdminConnection(ctx, rt, githubConfigURL) - if err != nil { - return nil, fmt.Errorf("failed to get actions service admin connection: %w", err) - } - - ac.ActionsServiceURL = adminConnInfo.ActionsServiceUrl - - ac.mu.Lock() - defer ac.mu.Unlock() - ac.ActionsServiceAdminToken = adminConnInfo.AdminToken - ac.ActionsServiceAdminTokenExpiresAt, err = actionsServiceAdminTokenExpiresAt(*adminConnInfo.AdminToken) - if err != nil { - return nil, fmt.Errorf("failed to get admin token expire at: %w", err) - } - return ac, nil } -func (c *Client) GetRunnerScaleSet(ctx context.Context, runnerScaleSetName string) (*RunnerScaleSet, error) { - u := fmt.Sprintf("%s/%s?name=%s&api-version=6.0-preview", *c.ActionsServiceURL, scaleSetEndpoint, runnerScaleSetName) - - if err := c.refreshTokenIfNeeded(ctx); err != nil { - return nil, fmt.Errorf("failed to refresh admin token if needed: %w", err) +func (c *Client) Do(req *http.Request) (*http.Response, error) { + resp, err := c.Client.Do(req) + if err != nil { + return nil, err } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + err = resp.Body.Close() + if err != nil { + return nil, err + } + + body = trimByteOrderMark(body) + resp.Body = io.NopCloser(bytes.NewReader(body)) + return resp, nil +} + +func (c *Client) NewGitHubAPIRequest(ctx context.Context, method, path string, body io.Reader) (*http.Request, error) { + u := c.config.GitHubAPIURL(path) + req, err := http.NewRequestWithContext(ctx, method, u.String(), body) + if err != nil { + return nil, err + } + + if c.userAgent != "" { + req.Header.Set("User-Agent", c.userAgent) + } + + return req, nil +} + +func (c *Client) NewActionsServiceRequest(ctx context.Context, method, path string, body io.Reader) (*http.Request, error) { + err := c.updateTokenIfNeeded(ctx) + if err != nil { + return nil, err + } + + parsedPath, err := url.Parse(path) + if err != nil { + return nil, err + } + + urlString, err := url.JoinPath(c.ActionsServiceURL, parsedPath.Path) + if err != nil { + return nil, err + } + + u, err := url.Parse(urlString) + if err != nil { + return nil, err + } + + q := u.Query() + for k, v := range parsedPath.Query() { + q[k] = v + } + if q.Get("api-version") == "" { + q.Set("api-version", "6.0-preview") + } + u.RawQuery = q.Encode() + + req, err := http.NewRequestWithContext(ctx, method, u.String(), body) if err != nil { return nil, err } req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *c.ActionsServiceAdminToken)) - + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.ActionsServiceAdminToken)) if c.userAgent != "" { req.Header.Set("User-Agent", c.userAgent) } + return req, nil +} + +func (c *Client) GetRunnerScaleSet(ctx context.Context, runnerScaleSetName string) (*RunnerScaleSet, error) { + path := fmt.Sprintf("/%s?name=%s", scaleSetEndpoint, runnerScaleSetName) + req, err := c.NewActionsServiceRequest(ctx, http.MethodGet, path, nil) + if err != nil { + return nil, err + } + resp, err := c.Do(req) if err != nil { return nil, err @@ -211,8 +257,9 @@ func (c *Client) GetRunnerScaleSet(ctx context.Context, runnerScaleSetName strin if resp.StatusCode != http.StatusOK { return nil, ParseActionsErrorFromResponse(resp) } + var runnerScaleSetList *runnerScaleSetsResponse - err = unmarshalBody(resp, &runnerScaleSetList) + err = json.NewDecoder(resp.Body).Decode(&runnerScaleSetList) if err != nil { return nil, err } @@ -227,24 +274,12 @@ func (c *Client) GetRunnerScaleSet(ctx context.Context, runnerScaleSetName strin } func (c *Client) GetRunnerScaleSetById(ctx context.Context, runnerScaleSetId int) (*RunnerScaleSet, error) { - u := fmt.Sprintf("%s/%s/%d?api-version=6.0-preview", *c.ActionsServiceURL, scaleSetEndpoint, runnerScaleSetId) - - if err := c.refreshTokenIfNeeded(ctx); err != nil { - return nil, fmt.Errorf("failed to refresh admin token if needed: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) + path := fmt.Sprintf("/%s/%d", scaleSetEndpoint, runnerScaleSetId) + req, err := c.NewActionsServiceRequest(ctx, http.MethodGet, path, nil) if err != nil { return nil, err } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *c.ActionsServiceAdminToken)) - - if c.userAgent != "" { - req.Header.Set("User-Agent", c.userAgent) - } - resp, err := c.Do(req) if err != nil { return nil, err @@ -255,7 +290,7 @@ func (c *Client) GetRunnerScaleSetById(ctx context.Context, runnerScaleSetId int } var runnerScaleSet *RunnerScaleSet - err = unmarshalBody(resp, &runnerScaleSet) + err = json.NewDecoder(resp.Body).Decode(&runnerScaleSet) if err != nil { return nil, err } @@ -263,24 +298,12 @@ func (c *Client) GetRunnerScaleSetById(ctx context.Context, runnerScaleSetId int } func (c *Client) GetRunnerGroupByName(ctx context.Context, runnerGroup string) (*RunnerGroup, error) { - u := fmt.Sprintf("%s/_apis/runtime/runnergroups/?groupName=%s&api-version=6.0-preview", *c.ActionsServiceURL, runnerGroup) - - if err := c.refreshTokenIfNeeded(ctx); err != nil { - return nil, fmt.Errorf("failed to refresh admin token if needed: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) + path := fmt.Sprintf("/_apis/runtime/runnergroups/?groupName=%s", runnerGroup) + req, err := c.NewActionsServiceRequest(ctx, http.MethodGet, path, nil) if err != nil { return nil, err } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *c.ActionsServiceAdminToken)) - - if c.userAgent != "" { - req.Header.Set("User-Agent", c.userAgent) - } - resp, err := c.Do(req) if err != nil { return nil, err @@ -295,7 +318,7 @@ func (c *Client) GetRunnerGroupByName(ctx context.Context, runnerGroup string) ( } var runnerGroupList *RunnerGroupList - err = unmarshalBody(resp, &runnerGroupList) + err = json.NewDecoder(resp.Body).Decode(&runnerGroupList) if err != nil { return nil, err } @@ -312,29 +335,16 @@ func (c *Client) GetRunnerGroupByName(ctx context.Context, runnerGroup string) ( } func (c *Client) CreateRunnerScaleSet(ctx context.Context, runnerScaleSet *RunnerScaleSet) (*RunnerScaleSet, error) { - u := fmt.Sprintf("%s/%s?api-version=6.0-preview", *c.ActionsServiceURL, scaleSetEndpoint) - - if err := c.refreshTokenIfNeeded(ctx); err != nil { - return nil, fmt.Errorf("failed to refresh admin token if needed: %w", err) - } - body, err := json.Marshal(runnerScaleSet) if err != nil { return nil, err } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, u, bytes.NewBuffer(body)) + req, err := c.NewActionsServiceRequest(ctx, http.MethodPost, scaleSetEndpoint, bytes.NewReader(body)) if err != nil { return nil, err } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *c.ActionsServiceAdminToken)) - - if c.userAgent != "" { - req.Header.Set("User-Agent", c.userAgent) - } - resp, err := c.Do(req) if err != nil { return nil, err @@ -344,7 +354,7 @@ func (c *Client) CreateRunnerScaleSet(ctx context.Context, runnerScaleSet *Runne return nil, ParseActionsErrorFromResponse(resp) } var createdRunnerScaleSet *RunnerScaleSet - err = unmarshalBody(resp, &createdRunnerScaleSet) + err = json.NewDecoder(resp.Body).Decode(&createdRunnerScaleSet) if err != nil { return nil, err } @@ -352,29 +362,18 @@ func (c *Client) CreateRunnerScaleSet(ctx context.Context, runnerScaleSet *Runne } func (c *Client) UpdateRunnerScaleSet(ctx context.Context, runnerScaleSetId int, runnerScaleSet *RunnerScaleSet) (*RunnerScaleSet, error) { - u := fmt.Sprintf("%s/%s/%d?api-version=6.0-preview", *c.ActionsServiceURL, scaleSetEndpoint, runnerScaleSetId) - - if err := c.refreshTokenIfNeeded(ctx); err != nil { - return nil, fmt.Errorf("failed to refresh admin token if needed: %w", err) - } + path := fmt.Sprintf("%s/%d", scaleSetEndpoint, runnerScaleSetId) body, err := json.Marshal(runnerScaleSet) if err != nil { return nil, err } - req, err := http.NewRequestWithContext(ctx, http.MethodPatch, u, bytes.NewBuffer(body)) + req, err := c.NewActionsServiceRequest(ctx, http.MethodPatch, path, bytes.NewReader(body)) if err != nil { return nil, err } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *c.ActionsServiceAdminToken)) - - if c.userAgent != "" { - req.Header.Set("User-Agent", c.userAgent) - } - resp, err := c.Do(req) if err != nil { return nil, err @@ -385,7 +384,7 @@ func (c *Client) UpdateRunnerScaleSet(ctx context.Context, runnerScaleSetId int, } var updatedRunnerScaleSet *RunnerScaleSet - err = unmarshalBody(resp, &updatedRunnerScaleSet) + err = json.NewDecoder(resp.Body).Decode(&updatedRunnerScaleSet) if err != nil { return nil, err } @@ -393,24 +392,12 @@ func (c *Client) UpdateRunnerScaleSet(ctx context.Context, runnerScaleSetId int, } func (c *Client) DeleteRunnerScaleSet(ctx context.Context, runnerScaleSetId int) error { - u := fmt.Sprintf("%s/%s/%d?api-version=6.0-preview", *c.ActionsServiceURL, scaleSetEndpoint, runnerScaleSetId) - - if err := c.refreshTokenIfNeeded(ctx); err != nil { - return fmt.Errorf("failed to refresh admin token if needed: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodDelete, u, nil) + path := fmt.Sprintf("/%s/%d", scaleSetEndpoint, runnerScaleSetId) + req, err := c.NewActionsServiceRequest(ctx, http.MethodDelete, path, nil) if err != nil { return err } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *c.ActionsServiceAdminToken)) - - if c.userAgent != "" { - req.Header.Set("User-Agent", c.userAgent) - } - resp, err := c.Do(req) if err != nil { return err @@ -425,12 +412,18 @@ func (c *Client) DeleteRunnerScaleSet(ctx context.Context, runnerScaleSetId int) } func (c *Client) GetMessage(ctx context.Context, messageQueueUrl, messageQueueAccessToken string, lastMessageId int64) (*RunnerScaleSetMessage, error) { - u := messageQueueUrl - if lastMessageId > 0 { - u = fmt.Sprintf("%s&lassMessageId=%d", messageQueueUrl, lastMessageId) + u, err := url.Parse(messageQueueUrl) + if err != nil { + return nil, err } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) + if lastMessageId > 0 { + q := u.Query() + q.Set("lastMessageId", strconv.FormatInt(lastMessageId, 10)) + u.RawQuery = q.Encode() + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) if err != nil { return nil, err } @@ -466,7 +459,7 @@ func (c *Client) GetMessage(ctx context.Context, messageQueueUrl, messageQueueAc } var message *RunnerScaleSetMessage - err = unmarshalBody(resp, &message) + err = json.NewDecoder(resp.Body).Decode(&message) if err != nil { return nil, err } @@ -514,7 +507,7 @@ func (c *Client) DeleteMessage(ctx context.Context, messageQueueUrl, messageQueu } func (c *Client) CreateMessageSession(ctx context.Context, runnerScaleSetId int, owner string) (*RunnerScaleSetSession, error) { - u := fmt.Sprintf("%v/%v/%v/sessions?%v", *c.ActionsServiceURL, scaleSetEndpoint, runnerScaleSetId, apiVersionQueryParam) + path := fmt.Sprintf("/%s/%d/sessions", scaleSetEndpoint, runnerScaleSetId) newSession := &RunnerScaleSetSession{ OwnerName: owner, @@ -527,49 +520,36 @@ func (c *Client) CreateMessageSession(ctx context.Context, runnerScaleSetId int, createdSession := &RunnerScaleSetSession{} - err = c.doSessionRequest(ctx, http.MethodPost, u, bytes.NewBuffer(requestData), http.StatusOK, createdSession) + err = c.doSessionRequest(ctx, http.MethodPost, path, bytes.NewBuffer(requestData), http.StatusOK, createdSession) return createdSession, err } func (c *Client) DeleteMessageSession(ctx context.Context, runnerScaleSetId int, sessionId *uuid.UUID) error { - u := fmt.Sprintf("%v/%v/%v/sessions/%v?%v", *c.ActionsServiceURL, scaleSetEndpoint, runnerScaleSetId, sessionId.String(), apiVersionQueryParam) - - return c.doSessionRequest(ctx, http.MethodDelete, u, nil, http.StatusNoContent, nil) + path := fmt.Sprintf("/%s/%d/sessions/%s", scaleSetEndpoint, runnerScaleSetId, sessionId.String()) + return c.doSessionRequest(ctx, http.MethodDelete, path, nil, http.StatusNoContent, nil) } func (c *Client) RefreshMessageSession(ctx context.Context, runnerScaleSetId int, sessionId *uuid.UUID) (*RunnerScaleSetSession, error) { - u := fmt.Sprintf("%v/%v/%v/sessions/%v?%v", *c.ActionsServiceURL, scaleSetEndpoint, runnerScaleSetId, sessionId.String(), apiVersionQueryParam) + path := fmt.Sprintf("/%s/%d/sessions/%s", scaleSetEndpoint, runnerScaleSetId, sessionId.String()) refreshedSession := &RunnerScaleSetSession{} - err := c.doSessionRequest(ctx, http.MethodPatch, u, nil, http.StatusOK, refreshedSession) + err := c.doSessionRequest(ctx, http.MethodPatch, path, nil, http.StatusOK, refreshedSession) return refreshedSession, err } -func (c *Client) doSessionRequest(ctx context.Context, method, url string, requestData io.Reader, expectedResponseStatusCode int, responseUnmarshalTarget any) error { - if err := c.refreshTokenIfNeeded(ctx); err != nil { - return fmt.Errorf("failed to refresh admin token if needed: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, method, url, requestData) +func (c *Client) doSessionRequest(ctx context.Context, method, path string, requestData io.Reader, expectedResponseStatusCode int, responseUnmarshalTarget any) error { + req, err := c.NewActionsServiceRequest(ctx, method, path, requestData) if err != nil { return err } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *c.ActionsServiceAdminToken)) - - if c.userAgent != "" { - req.Header.Set("User-Agent", c.userAgent) - } - resp, err := c.Do(req) if err != nil { return err } if resp.StatusCode == expectedResponseStatusCode && responseUnmarshalTarget != nil { - err = unmarshalBody(resp, &responseUnmarshalTarget) - return err + return json.NewDecoder(resp.Body).Decode(responseUnmarshalTarget) } if resp.StatusCode >= 400 && resp.StatusCode < 500 { @@ -587,7 +567,7 @@ func (c *Client) doSessionRequest(ctx context.Context, method, url string, reque } func (c *Client) AcquireJobs(ctx context.Context, runnerScaleSetId int, messageQueueAccessToken string, requestIds []int64) ([]int64, error) { - u := fmt.Sprintf("%s/%s/%d/acquirejobs?api-version=6.0-preview", *c.ActionsServiceURL, scaleSetEndpoint, runnerScaleSetId) + u := fmt.Sprintf("%s/%s/%d/acquirejobs?api-version=6.0-preview", c.ActionsServiceURL, scaleSetEndpoint, runnerScaleSetId) body, err := json.Marshal(requestIds) if err != nil { @@ -614,8 +594,8 @@ func (c *Client) AcquireJobs(ctx context.Context, runnerScaleSetId int, messageQ return nil, ParseActionsErrorFromResponse(resp) } - var acquiredJobs Int64List - err = unmarshalBody(resp, &acquiredJobs) + var acquiredJobs *Int64List + err = json.NewDecoder(resp.Body).Decode(&acquiredJobs) if err != nil { return nil, err } @@ -624,24 +604,13 @@ func (c *Client) AcquireJobs(ctx context.Context, runnerScaleSetId int, messageQ } func (c *Client) GetAcquirableJobs(ctx context.Context, runnerScaleSetId int) (*AcquirableJobList, error) { - u := fmt.Sprintf("%s/%s/%d/acquirablejobs?api-version=6.0-preview", *c.ActionsServiceURL, scaleSetEndpoint, runnerScaleSetId) + path := fmt.Sprintf("/%s/%d/acquirablejobs", scaleSetEndpoint, runnerScaleSetId) - if err := c.refreshTokenIfNeeded(ctx); err != nil { - return nil, fmt.Errorf("failed to refresh admin token if needed: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) + req, err := c.NewActionsServiceRequest(ctx, http.MethodGet, path, nil) if err != nil { return nil, err } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *c.ActionsServiceAdminToken)) - - if c.userAgent != "" { - req.Header.Set("User-Agent", c.userAgent) - } - resp, err := c.Do(req) if err != nil { return nil, err @@ -657,7 +626,7 @@ func (c *Client) GetAcquirableJobs(ctx context.Context, runnerScaleSetId int) (* } var acquirableJobList *AcquirableJobList - err = unmarshalBody(resp, &acquirableJobList) + err = json.NewDecoder(resp.Body).Decode(&acquirableJobList) if err != nil { return nil, err } @@ -666,28 +635,18 @@ func (c *Client) GetAcquirableJobs(ctx context.Context, runnerScaleSetId int) (* } func (c *Client) GenerateJitRunnerConfig(ctx context.Context, jitRunnerSetting *RunnerScaleSetJitRunnerSetting, scaleSetId int) (*RunnerScaleSetJitRunnerConfig, error) { - runnerJitConfigUrl := fmt.Sprintf("%s/%s/%d/generatejitconfig?api-version=6.0-preview", *c.ActionsServiceURL, scaleSetEndpoint, scaleSetId) - - if err := c.refreshTokenIfNeeded(ctx); err != nil { - return nil, fmt.Errorf("failed to refresh admin token if needed: %w", err) - } + path := fmt.Sprintf("/%s/%d/generatejitconfig", scaleSetEndpoint, scaleSetId) body, err := json.Marshal(jitRunnerSetting) if err != nil { return nil, err } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, runnerJitConfigUrl, bytes.NewBuffer(body)) + req, err := c.NewActionsServiceRequest(ctx, http.MethodPost, path, bytes.NewBuffer(body)) if err != nil { return nil, err } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *c.ActionsServiceAdminToken)) - if c.userAgent != "" { - req.Header.Set("User-Agent", c.userAgent) - } - resp, err := c.Do(req) if err != nil { return nil, err @@ -698,7 +657,7 @@ func (c *Client) GenerateJitRunnerConfig(ctx context.Context, jitRunnerSetting * } var runnerJitConfig *RunnerScaleSetJitRunnerConfig - err = unmarshalBody(resp, &runnerJitConfig) + err = json.NewDecoder(resp.Body).Decode(&runnerJitConfig) if err != nil { return nil, err } @@ -706,24 +665,13 @@ func (c *Client) GenerateJitRunnerConfig(ctx context.Context, jitRunnerSetting * } func (c *Client) GetRunner(ctx context.Context, runnerId int64) (*RunnerReference, error) { - url := fmt.Sprintf("%v/%v/%v?%v", *c.ActionsServiceURL, runnerEndpoint, runnerId, apiVersionQueryParam) + path := fmt.Sprintf("/%s/%d", runnerEndpoint, runnerId) - if err := c.refreshTokenIfNeeded(ctx); err != nil { - return nil, fmt.Errorf("failed to refresh admin token if needed: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + req, err := c.NewActionsServiceRequest(ctx, http.MethodGet, path, nil) if err != nil { return nil, err } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *c.ActionsServiceAdminToken)) - - if c.userAgent != "" { - req.Header.Set("User-Agent", c.userAgent) - } - resp, err := c.Do(req) if err != nil { return nil, err @@ -734,7 +682,8 @@ func (c *Client) GetRunner(ctx context.Context, runnerId int64) (*RunnerReferenc } var runnerReference *RunnerReference - if err := unmarshalBody(resp, &runnerReference); err != nil { + err = json.NewDecoder(resp.Body).Decode(&runnerReference) + if err != nil { return nil, err } @@ -742,24 +691,13 @@ func (c *Client) GetRunner(ctx context.Context, runnerId int64) (*RunnerReferenc } func (c *Client) GetRunnerByName(ctx context.Context, runnerName string) (*RunnerReference, error) { - url := fmt.Sprintf("%v/%v?agentName=%v&%v", *c.ActionsServiceURL, runnerEndpoint, runnerName, apiVersionQueryParam) + path := fmt.Sprintf("/%s?agentName=%s", runnerEndpoint, runnerName) - if err := c.refreshTokenIfNeeded(ctx); err != nil { - return nil, fmt.Errorf("failed to refresh admin token if needed: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + req, err := c.NewActionsServiceRequest(ctx, http.MethodGet, path, nil) if err != nil { return nil, err } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *c.ActionsServiceAdminToken)) - - if c.userAgent != "" { - req.Header.Set("User-Agent", c.userAgent) - } - resp, err := c.Do(req) if err != nil { return nil, err @@ -770,7 +708,7 @@ func (c *Client) GetRunnerByName(ctx context.Context, runnerName string) (*Runne } var runnerList *RunnerReferenceList - err = unmarshalBody(resp, &runnerList) + err = json.NewDecoder(resp.Body).Decode(&runnerList) if err != nil { return nil, err } @@ -787,24 +725,13 @@ func (c *Client) GetRunnerByName(ctx context.Context, runnerName string) (*Runne } func (c *Client) RemoveRunner(ctx context.Context, runnerId int64) error { - url := fmt.Sprintf("%v/%v/%v?%v", *c.ActionsServiceURL, runnerEndpoint, runnerId, apiVersionQueryParam) + path := fmt.Sprintf("/%s/%d", runnerEndpoint, runnerId) - if err := c.refreshTokenIfNeeded(ctx); err != nil { - return fmt.Errorf("failed to refresh admin token if needed: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodDelete, url, nil) + req, err := c.NewActionsServiceRequest(ctx, http.MethodDelete, path, nil) if err != nil { return err } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *c.ActionsServiceAdminToken)) - - if c.userAgent != "" { - req.Header.Set("User-Agent", c.userAgent) - } - resp, err := c.Do(req) if err != nil { return err @@ -823,25 +750,25 @@ type registrationToken struct { ExpiresAt *time.Time `json:"expires_at,omitempty"` } -func (c *Client) getRunnerRegistrationToken(ctx context.Context, githubConfigUrl string, creds ActionsAuth) (*registrationToken, error) { - registrationTokenURL, err := createRegistrationTokenURL(githubConfigUrl) +func (c *Client) getRunnerRegistrationToken(ctx context.Context) (*registrationToken, error) { + path, err := createRegistrationTokenPath(c.config) if err != nil { return nil, err } var buf bytes.Buffer - req, err := http.NewRequestWithContext(ctx, http.MethodPost, registrationTokenURL, &buf) + req, err := c.NewGitHubAPIRequest(ctx, http.MethodPost, path, &buf) if err != nil { return nil, err } bearerToken := "" - if creds.Token != "" { - encodedToken := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("github:%v", creds.Token))) + if c.creds.Token != "" { + encodedToken := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("github:%v", c.creds.Token))) bearerToken = fmt.Sprintf("Basic %v", encodedToken) } else { - accessToken, err := c.fetchAccessToken(ctx, githubConfigUrl, creds.AppCreds) + accessToken, err := c.fetchAccessToken(ctx, c.config.ConfigURL.String(), c.creds.AppCreds) if err != nil { return nil, err } @@ -851,9 +778,8 @@ func (c *Client) getRunnerRegistrationToken(ctx context.Context, githubConfigUrl req.Header.Set("Content-Type", "application/vnd.github.v3+json") req.Header.Set("Authorization", bearerToken) - req.Header.Set("User-Agent", c.userAgent) - c.logger.Info("getting runner registration token", "registrationTokenURL", registrationTokenURL) + c.logger.Info("getting runner registration token", "registrationTokenURL", req.URL.String()) resp, err := c.Do(req) if err != nil { @@ -869,8 +795,8 @@ func (c *Client) getRunnerRegistrationToken(ctx context.Context, githubConfigUrl return nil, fmt.Errorf("unexpected response from Actions service during registration token call: %v - %v", resp.StatusCode, string(body)) } - registrationToken := ®istrationToken{} - if err := json.NewDecoder(resp.Body).Decode(registrationToken); err != nil { + var registrationToken *registrationToken + if err := json.NewDecoder(resp.Body).Decode(®istrationToken); err != nil { return nil, err } @@ -889,21 +815,16 @@ func (c *Client) fetchAccessToken(ctx context.Context, gitHubConfigURL string, c return nil, err } - u, err := githubAPIURL(gitHubConfigURL, fmt.Sprintf("/app/installations/%v/access_tokens", creds.AppInstallationID)) - if err != nil { - return nil, err - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, u, nil) + path := fmt.Sprintf("/app/installations/%v/access_tokens", creds.AppInstallationID) + req, err := c.NewGitHubAPIRequest(ctx, http.MethodPost, path, nil) if err != nil { return nil, err } req.Header.Set("Content-Type", "application/vnd.github+json") req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessTokenJWT)) - req.Header.Add("User-Agent", c.userAgent) - c.logger.Info("getting access token for GitHub App auth", "accessTokenURL", u) + c.logger.Info("getting access token for GitHub App auth", "accessTokenURL", req.URL.String()) resp, err := c.Do(req) if err != nil { @@ -912,8 +833,8 @@ func (c *Client) fetchAccessToken(ctx context.Context, gitHubConfigURL string, c defer resp.Body.Close() // Format: https://docs.github.com/en/rest/apps/apps#create-an-installation-access-token-for-an-app - accessToken := &accessToken{} - err = json.NewDecoder(resp.Body).Decode(accessToken) + var accessToken *accessToken + err = json.NewDecoder(resp.Body).Decode(&accessToken) return accessToken, err } @@ -922,27 +843,14 @@ type ActionsServiceAdminConnection struct { AdminToken *string `json:"token,omitempty"` } -func (c *Client) getActionsServiceAdminConnection(ctx context.Context, rt *registrationToken, githubConfigUrl string) (*ActionsServiceAdminConnection, error) { - parsedGitHubConfigURL, err := url.Parse(githubConfigUrl) - if err != nil { - return nil, err - } - - if isHostedServer(*parsedGitHubConfigURL) { - parsedGitHubConfigURL.Host = fmt.Sprintf("api.%v", parsedGitHubConfigURL.Host) - } - - ru := fmt.Sprintf("%v://%v/actions/runner-registration", parsedGitHubConfigURL.Scheme, parsedGitHubConfigURL.Host) - registrationURL, err := url.Parse(ru) - if err != nil { - return nil, err - } +func (c *Client) getActionsServiceAdminConnection(ctx context.Context, rt *registrationToken) (*ActionsServiceAdminConnection, error) { + path := "/actions/runner-registration" body := struct { Url string `json:"url"` RunnerEvent string `json:"runner_event"` }{ - Url: githubConfigUrl, + Url: c.config.ConfigURL.String(), RunnerEvent: "register", } @@ -954,16 +862,15 @@ func (c *Client) getActionsServiceAdminConnection(ctx context.Context, rt *regis return nil, err } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, registrationURL.String(), buf) + req, err := c.NewGitHubAPIRequest(ctx, http.MethodPost, path, buf) if err != nil { return nil, err } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", fmt.Sprintf("RemoteAuth %s", *rt.Token)) - req.Header.Set("User-Agent", c.userAgent) - c.logger.Info("getting Actions tenant URL and JWT", "registrationURL", registrationURL.String()) + c.logger.Info("getting Actions tenant URL and JWT", "registrationURL", req.URL.String()) resp, err := c.Do(req) if err != nil { @@ -971,65 +878,30 @@ func (c *Client) getActionsServiceAdminConnection(ctx context.Context, rt *regis } defer resp.Body.Close() - actionsServiceAdminConnection := &ActionsServiceAdminConnection{} - if err := json.NewDecoder(resp.Body).Decode(actionsServiceAdminConnection); err != nil { + var actionsServiceAdminConnection *ActionsServiceAdminConnection + if err := json.NewDecoder(resp.Body).Decode(&actionsServiceAdminConnection); err != nil { return nil, err } return actionsServiceAdminConnection, nil } -func isHostedServer(gitHubURL url.URL) bool { - return gitHubURL.Host == "github.com" || - gitHubURL.Host == "www.github.com" || - gitHubURL.Host == "github.localhost" -} +func createRegistrationTokenPath(config *GitHubConfig) (string, error) { + switch config.Scope { + case GitHubScopeOrganization: + path := fmt.Sprintf("/orgs/%s/actions/runners/registration-token", config.Organization) + return path, nil -func createRegistrationTokenURL(githubConfigUrl string) (string, error) { - parsedGitHubConfigURL, err := url.Parse(githubConfigUrl) - if err != nil { - return "", err - } + case GitHubScopeEnterprise: + path := fmt.Sprintf("/enterprises/%s/actions/runners/registration-token", config.Enterprise) + return path, nil - // Check for empty path before split, because strings.Split will return a slice of length 1 - // when the split delimiter is not present. - trimmedPath := strings.TrimLeft(parsedGitHubConfigURL.Path, "/") - if len(trimmedPath) == 0 { - return "", fmt.Errorf("%q should point to an enterprise, org, or repository", parsedGitHubConfigURL.String()) - } + case GitHubScopeRepository: + path := fmt.Sprintf("/repos/%s/%s/actions/runners/registration-token", config.Organization, config.Repository) + return path, nil - pathParts := strings.Split(path.Clean(strings.TrimLeft(parsedGitHubConfigURL.Path, "/")), "/") - - switch len(pathParts) { - case 1: // Organization - registrationTokenURL := fmt.Sprintf( - "%v://%v/api/v3/orgs/%v/actions/runners/registration-token", - parsedGitHubConfigURL.Scheme, parsedGitHubConfigURL.Host, pathParts[0]) - - if isHostedServer(*parsedGitHubConfigURL) { - registrationTokenURL = fmt.Sprintf( - "%v://api.%v/orgs/%v/actions/runners/registration-token", - parsedGitHubConfigURL.Scheme, parsedGitHubConfigURL.Host, pathParts[0]) - } - - return registrationTokenURL, nil - case 2: // Repository or enterprise - repoScope := "repos/" - if strings.ToLower(pathParts[0]) == "enterprises" { - repoScope = "" - } - - registrationTokenURL := fmt.Sprintf("%v://%v/api/v3/%v%v/%v/actions/runners/registration-token", - parsedGitHubConfigURL.Scheme, parsedGitHubConfigURL.Host, repoScope, pathParts[0], pathParts[1]) - - if isHostedServer(*parsedGitHubConfigURL) { - registrationTokenURL = fmt.Sprintf("%v://api.%v/%v%v/%v/actions/runners/registration-token", - parsedGitHubConfigURL.Scheme, parsedGitHubConfigURL.Host, repoScope, pathParts[0], pathParts[1]) - } - - return registrationTokenURL, nil default: - return "", fmt.Errorf("%q should point to an enterprise, org, or repository", parsedGitHubConfigURL.String()) + return "", fmt.Errorf("unknown scope for config url: %s", config.ConfigURL) } } @@ -1057,68 +929,50 @@ func createJWTForGitHubApp(appAuth *GitHubAppAuth) (string, error) { return token.SignedString(privateKey) } -func unmarshalBody(response *http.Response, v interface{}) (err error) { - if response != nil && response.Body != nil { - var err error - defer func() { - if closeError := response.Body.Close(); closeError != nil { - err = closeError - } - }() - body, err := io.ReadAll(response.Body) - if err != nil { - return err - } - body = trimByteOrderMark(body) - return json.Unmarshal(body, &v) - } - return nil -} - // Returns slice of body without utf-8 byte order mark. // If BOM does not exist body is returned unchanged. func trimByteOrderMark(body []byte) []byte { return bytes.TrimPrefix(body, []byte("\xef\xbb\xbf")) } -func actionsServiceAdminTokenExpiresAt(jwtToken string) (*time.Time, error) { +func actionsServiceAdminTokenExpiresAt(jwtToken string) (time.Time, error) { type JwtClaims struct { jwt.RegisteredClaims } token, _, err := jwt.NewParser().ParseUnverified(jwtToken, &JwtClaims{}) if err != nil { - return nil, fmt.Errorf("failed to parse jwt token: %w", err) + return time.Time{}, fmt.Errorf("failed to parse jwt token: %w", err) } if claims, ok := token.Claims.(*JwtClaims); ok { - return &claims.ExpiresAt.Time, nil + return claims.ExpiresAt.Time, nil } - return nil, fmt.Errorf("failed to parse token claims to get expire at") + return time.Time{}, fmt.Errorf("failed to parse token claims to get expire at") } -func (c *Client) refreshTokenIfNeeded(ctx context.Context) error { +func (c *Client) updateTokenIfNeeded(ctx context.Context) error { c.mu.Lock() defer c.mu.Unlock() - aboutToExpire := time.Now().Add(60 * time.Second).After(*c.ActionsServiceAdminTokenExpiresAt) - if !aboutToExpire { + aboutToExpire := time.Now().Add(60 * time.Second).After(c.ActionsServiceAdminTokenExpiresAt) + if !aboutToExpire && !c.ActionsServiceAdminTokenExpiresAt.IsZero() { return nil } - c.logger.Info("Admin token is about to expire, refreshing it", "githubConfigUrl", c.githubConfigURL) - rt, err := c.getRunnerRegistrationToken(ctx, c.githubConfigURL, *c.creds) + c.logger.Info("refreshing token", "githubConfigUrl", c.config.ConfigURL.String()) + rt, err := c.getRunnerRegistrationToken(ctx) if err != nil { - return fmt.Errorf("failed to get runner registration token on fresh: %w", err) + return fmt.Errorf("failed to get runner registration token on refresh: %w", err) } - adminConnInfo, err := c.getActionsServiceAdminConnection(ctx, rt, c.githubConfigURL) + adminConnInfo, err := c.getActionsServiceAdminConnection(ctx, rt) if err != nil { - return fmt.Errorf("failed to get actions service admin connection on fresh: %w", err) + return fmt.Errorf("failed to get actions service admin connection on refresh: %w", err) } - c.ActionsServiceURL = adminConnInfo.ActionsServiceUrl - c.ActionsServiceAdminToken = adminConnInfo.AdminToken + c.ActionsServiceURL = *adminConnInfo.ActionsServiceUrl + c.ActionsServiceAdminToken = *adminConnInfo.AdminToken c.ActionsServiceAdminTokenExpiresAt, err = actionsServiceAdminTokenExpiresAt(*adminConnInfo.AdminToken) if err != nil { return fmt.Errorf("failed to get admin token expire at on refresh: %w", err) @@ -1126,32 +980,3 @@ func (c *Client) refreshTokenIfNeeded(ctx context.Context) error { return nil } - -func githubAPIURL(configURL, path string) (string, error) { - u, err := url.Parse(configURL) - if err != nil { - return "", err - } - - result := &url.URL{ - Scheme: u.Scheme, - } - - switch u.Host { - // Hosted - case "github.com", "github.localhost": - result.Host = fmt.Sprintf("api.%s", u.Host) - // re-routing www.github.com to api.github.com - case "www.github.com": - result.Host = "api.github.com" - - // Enterprise - default: - result.Host = u.Host - result.Path = "/api/v3" - } - - result.Path += path - - return result.String(), nil -} diff --git a/github/actions/client_generate_jit_test.go b/github/actions/client_generate_jit_test.go index cf594151..94f9d537 100644 --- a/github/actions/client_generate_jit_test.go +++ b/github/actions/client_generate_jit_test.go @@ -26,7 +26,7 @@ func TestGenerateJitRunnerConfig(t *testing.T) { server := newActionsServer(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Write(response) })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) got, err := client.GenerateJitRunnerConfig(ctx, runnerSettings, 1) @@ -47,7 +47,6 @@ func TestGenerateJitRunnerConfig(t *testing.T) { })) client, err := actions.NewClient( - ctx, server.configURLForOrg("my-org"), auth, actions.WithRetryMax(1), diff --git a/github/actions/client_job_acquisition_test.go b/github/actions/client_job_acquisition_test.go index dfd0d58d..e11b7ba2 100644 --- a/github/actions/client_job_acquisition_test.go +++ b/github/actions/client_job_acquisition_test.go @@ -3,6 +3,7 @@ package actions_test import ( "context" "net/http" + "strings" "testing" "time" @@ -27,11 +28,19 @@ func TestAcquireJobs(t *testing.T) { } requestIDs := want - server := newActionsServer(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + server := newActionsServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.HasSuffix(r.URL.Path, "/acquirablejobs") { + w.Write([]byte(`{"count": 1}`)) + return + } + w.Write(response) })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) + require.NoError(t, err) + + _, err = client.GetAcquirableJobs(ctx, 1) require.NoError(t, err) got, err := client.AcquireJobs(ctx, session.RunnerScaleSet.Id, session.MessageQueueAccessToken, requestIDs) @@ -50,13 +59,17 @@ func TestAcquireJobs(t *testing.T) { actualRetry := 0 expectedRetry := retryMax + 1 - server := newActionsServer(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + server := newActionsServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.HasSuffix(r.URL.Path, "/acquirablejobs") { + w.Write([]byte(`{"count": 1}`)) + return + } + w.WriteHeader(http.StatusServiceUnavailable) actualRetry++ })) client, err := actions.NewClient( - ctx, server.configURLForOrg("my-org"), auth, actions.WithRetryMax(retryMax), @@ -64,6 +77,9 @@ func TestAcquireJobs(t *testing.T) { ) require.NoError(t, err) + _, err = client.GetAcquirableJobs(ctx, 1) + require.NoError(t, err) + _, err = client.AcquireJobs(context.Background(), session.RunnerScaleSet.Id, session.MessageQueueAccessToken, requestIDs) assert.NotNil(t, err) assert.Equalf(t, actualRetry, expectedRetry, "A retry was expected after the first request but got: %v", actualRetry) @@ -71,7 +87,6 @@ func TestAcquireJobs(t *testing.T) { } func TestGetAcquirableJobs(t *testing.T) { - ctx := context.Background() auth := &actions.ActionsAuth{ Token: "token", } @@ -86,7 +101,7 @@ func TestGetAcquirableJobs(t *testing.T) { w.Write(response) })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) got, err := client.GetAcquirableJobs(context.Background(), runnerScaleSet.Id) @@ -108,7 +123,6 @@ func TestGetAcquirableJobs(t *testing.T) { })) client, err := actions.NewClient( - context.Background(), server.configURLForOrg("my-org"), auth, actions.WithRetryMax(retryMax), diff --git a/github/actions/client_runner_scale_set_message_test.go b/github/actions/client_runner_scale_set_message_test.go index 55e80267..0de77094 100644 --- a/github/actions/client_runner_scale_set_message_test.go +++ b/github/actions/client_runner_scale_set_message_test.go @@ -32,7 +32,7 @@ func TestGetMessage(t *testing.T) { w.Write(response) })) - client, err := actions.NewClient(ctx, s.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(s.configURLForOrg("my-org"), auth) require.NoError(t, err) got, err := client.GetMessage(ctx, s.URL, token, 0) @@ -40,6 +40,23 @@ func TestGetMessage(t *testing.T) { assert.Equal(t, want, got) }) + t.Run("GetMessage sets the last message id if not 0", func(t *testing.T) { + want := runnerScaleSetMessage + response := []byte(`{"messageId":1,"messageType":"rssType"}`) + s := newActionsServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + assert.Equal(t, "1", q.Get("lastMessageId")) + w.Write(response) + })) + + client, err := actions.NewClient(s.configURLForOrg("my-org"), auth) + require.NoError(t, err) + + got, err := client.GetMessage(ctx, s.URL, token, 1) + require.NoError(t, err) + assert.Equal(t, want, got) + }) + t.Run("Default retries on server error", func(t *testing.T) { retryMax := 1 @@ -52,7 +69,6 @@ func TestGetMessage(t *testing.T) { })) client, err := actions.NewClient( - ctx, server.configURLForOrg("my-org"), auth, actions.WithRetryMax(retryMax), @@ -70,7 +86,7 @@ func TestGetMessage(t *testing.T) { w.WriteHeader(http.StatusUnauthorized) })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) _, err = client.GetMessage(ctx, server.URL, token, 0) @@ -78,8 +94,7 @@ func TestGetMessage(t *testing.T) { var expectedErr *actions.MessageQueueTokenExpiredError require.True(t, errors.As(err, &expectedErr)) - }, - ) + }) t.Run("Status code not found", func(t *testing.T) { want := actions.ActionsError{ @@ -90,7 +105,7 @@ func TestGetMessage(t *testing.T) { w.WriteHeader(http.StatusNotFound) })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) _, err = client.GetMessage(ctx, server.URL, token, 0) @@ -104,7 +119,7 @@ func TestGetMessage(t *testing.T) { w.Header().Set("Content-Type", "text/plain") })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) _, err = client.GetMessage(ctx, server.URL, token, 0) @@ -129,7 +144,7 @@ func TestDeleteMessage(t *testing.T) { w.WriteHeader(http.StatusNoContent) })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) err = client.DeleteMessage(ctx, server.URL, token, runnerScaleSetMessage.MessageId) @@ -141,7 +156,7 @@ func TestDeleteMessage(t *testing.T) { w.WriteHeader(http.StatusUnauthorized) })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) err = client.DeleteMessage(ctx, server.URL, token, 0) @@ -156,7 +171,7 @@ func TestDeleteMessage(t *testing.T) { w.Header().Set("Content-Type", "text/plain") })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) err = client.DeleteMessage(ctx, server.URL, token, runnerScaleSetMessage.MessageId) @@ -175,7 +190,6 @@ func TestDeleteMessage(t *testing.T) { retryMax := 1 client, err := actions.NewClient( - ctx, server.configURLForOrg("my-org"), auth, actions.WithRetryMax(retryMax), @@ -197,7 +211,7 @@ func TestDeleteMessage(t *testing.T) { w.Write(rsl) })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) err = client.DeleteMessage(ctx, server.URL, token, runnerScaleSetMessage.MessageId+1) diff --git a/github/actions/client_runner_scale_set_session_test.go b/github/actions/client_runner_scale_set_session_test.go index f5fbceb7..7b2ab69d 100644 --- a/github/actions/client_runner_scale_set_session_test.go +++ b/github/actions/client_runner_scale_set_session_test.go @@ -51,7 +51,7 @@ func TestCreateMessageSession(t *testing.T) { w.Write(resp) })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) got, err := client.CreateMessageSession(ctx, runnerScaleSet.Id, owner) @@ -81,7 +81,7 @@ func TestCreateMessageSession(t *testing.T) { w.Write(resp) })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) _, err = client.CreateMessageSession(ctx, runnerScaleSet.Id, owner) @@ -120,7 +120,6 @@ func TestCreateMessageSession(t *testing.T) { wantRetries := retryMax + 1 client, err := actions.NewClient( - ctx, server.configURLForOrg("my-org"), auth, actions.WithRetryMax(retryMax), @@ -160,7 +159,6 @@ func TestDeleteMessageSession(t *testing.T) { wantRetries := retryMax + 1 client, err := actions.NewClient( - ctx, server.configURLForOrg("my-org"), auth, actions.WithRetryMax(retryMax), @@ -177,7 +175,6 @@ func TestDeleteMessageSession(t *testing.T) { } func TestRefreshMessageSession(t *testing.T) { - ctx := context.Background() auth := &actions.ActionsAuth{ Token: "token", } @@ -202,7 +199,6 @@ func TestRefreshMessageSession(t *testing.T) { wantRetries := retryMax + 1 client, err := actions.NewClient( - ctx, server.configURLForOrg("my-org"), auth, actions.WithRetryMax(retryMax), diff --git a/github/actions/client_runner_scale_set_test.go b/github/actions/client_runner_scale_set_test.go index 980b9846..5354a0e5 100644 --- a/github/actions/client_runner_scale_set_test.go +++ b/github/actions/client_runner_scale_set_test.go @@ -31,7 +31,7 @@ func TestGetRunnerScaleSet(t *testing.T) { w.Write(runnerScaleSetsResp) })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) got, err := client.GetRunnerScaleSet(ctx, scaleSetName) @@ -47,15 +47,16 @@ func TestGetRunnerScaleSet(t *testing.T) { url = *r.URL })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) _, err = client.GetRunnerScaleSet(ctx, scaleSetName) require.NoError(t, err) - u := url.String() - expectedUrl := fmt.Sprintf("/_apis/runtime/runnerscalesets?name=%s&api-version=6.0-preview", scaleSetName) - assert.Equal(t, expectedUrl, u) + expectedPath := "/tenant/123/_apis/runtime/runnerscalesets" + assert.Equal(t, expectedPath, url.Path) + assert.Equal(t, scaleSetName, url.Query().Get("name")) + assert.Equal(t, "6.0-preview", url.Query().Get("api-version")) }) t.Run("Status code not found", func(t *testing.T) { @@ -63,7 +64,7 @@ func TestGetRunnerScaleSet(t *testing.T) { w.WriteHeader(http.StatusNotFound) })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) _, err = client.GetRunnerScaleSet(ctx, scaleSetName) @@ -76,7 +77,7 @@ func TestGetRunnerScaleSet(t *testing.T) { w.Header().Set("Content-Type", "text/plain") })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) _, err = client.GetRunnerScaleSet(ctx, scaleSetName) @@ -94,7 +95,6 @@ func TestGetRunnerScaleSet(t *testing.T) { retryWaitMax := 1 * time.Microsecond client, err := actions.NewClient( - ctx, server.configURLForOrg("my-org"), auth, actions.WithRetryMax(retryMax), @@ -115,7 +115,7 @@ func TestGetRunnerScaleSet(t *testing.T) { w.Write(runnerScaleSetsResp) })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) got, err := client.GetRunnerScaleSet(ctx, scaleSetName) @@ -130,7 +130,7 @@ func TestGetRunnerScaleSet(t *testing.T) { w.Write(runnerScaleSetsResp) })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) _, err = client.GetRunnerScaleSet(ctx, scaleSetName) @@ -156,7 +156,7 @@ func TestGetRunnerScaleSetById(t *testing.T) { w.Write(rsl) })) - client, err := actions.NewClient(ctx, sservere.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(sservere.configURLForOrg("my-org"), auth) require.NoError(t, err) got, err := client.GetRunnerScaleSetById(ctx, runnerScaleSet.Id) @@ -174,15 +174,15 @@ func TestGetRunnerScaleSetById(t *testing.T) { url = *r.URL })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) _, err = client.GetRunnerScaleSetById(ctx, runnerScaleSet.Id) require.NoError(t, err) - u := url.String() - expectedUrl := fmt.Sprintf("/_apis/runtime/runnerscalesets/%d?api-version=6.0-preview", runnerScaleSet.Id) - assert.Equal(t, expectedUrl, u) + expectedPath := fmt.Sprintf("/tenant/123/_apis/runtime/runnerscalesets/%d", runnerScaleSet.Id) + assert.Equal(t, expectedPath, url.Path) + assert.Equal(t, "6.0-preview", url.Query().Get("api-version")) }) t.Run("Status code not found", func(t *testing.T) { @@ -190,7 +190,7 @@ func TestGetRunnerScaleSetById(t *testing.T) { w.WriteHeader(http.StatusNotFound) })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) _, err = client.GetRunnerScaleSetById(ctx, runnerScaleSet.Id) @@ -203,7 +203,7 @@ func TestGetRunnerScaleSetById(t *testing.T) { w.Header().Set("Content-Type", "text/plain") })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) _, err = client.GetRunnerScaleSetById(ctx, runnerScaleSet.Id) @@ -220,7 +220,6 @@ func TestGetRunnerScaleSetById(t *testing.T) { retryMax := 1 retryWaitMax := 1 * time.Microsecond client, err := actions.NewClient( - ctx, server.configURLForOrg("my-org"), auth, actions.WithRetryMax(retryMax), @@ -242,7 +241,7 @@ func TestGetRunnerScaleSetById(t *testing.T) { w.Write(rsl) })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) got, err := client.GetRunnerScaleSetById(ctx, runnerScaleSet.Id) @@ -268,7 +267,7 @@ func TestCreateRunnerScaleSet(t *testing.T) { w.Write(rsl) })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) got, err := client.CreateRunnerScaleSet(ctx, &runnerScaleSet) @@ -285,15 +284,15 @@ func TestCreateRunnerScaleSet(t *testing.T) { url = *r.URL })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) _, err = client.CreateRunnerScaleSet(ctx, &runnerScaleSet) require.NoError(t, err) - u := url.String() - expectedUrl := "/_apis/runtime/runnerscalesets?api-version=6.0-preview" - assert.Equal(t, expectedUrl, u) + expectedPath := "/tenant/123/_apis/runtime/runnerscalesets" + assert.Equal(t, expectedPath, url.Path) + assert.Equal(t, "6.0-preview", url.Query().Get("api-version")) }) t.Run("Error when Content-Type is text/plain", func(t *testing.T) { @@ -302,7 +301,7 @@ func TestCreateRunnerScaleSet(t *testing.T) { w.Header().Set("Content-Type", "text/plain") })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) _, err = client.CreateRunnerScaleSet(ctx, &runnerScaleSet) @@ -322,7 +321,6 @@ func TestCreateRunnerScaleSet(t *testing.T) { retryWaitMax := 1 * time.Microsecond client, err := actions.NewClient( - ctx, server.configURLForOrg("my-org"), auth, actions.WithRetryMax(retryMax), @@ -354,7 +352,7 @@ func TestUpdateRunnerScaleSet(t *testing.T) { w.Write(rsl) })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) got, err := client.UpdateRunnerScaleSet(ctx, 1, &actions.RunnerScaleSet{RunnerGroupId: 1}) @@ -365,24 +363,19 @@ func TestUpdateRunnerScaleSet(t *testing.T) { t.Run("UpdateRunnerScaleSet calls correct url", func(t *testing.T) { rsl, err := json.Marshal(&runnerScaleSet) require.NoError(t, err) - url := url.URL{} - method := "" server := newActionsServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + expectedPath := "/tenant/123/_apis/runtime/runnerscalesets/1" + assert.Equal(t, expectedPath, r.URL.Path) + assert.Equal(t, http.MethodPatch, r.Method) + assert.Equal(t, "6.0-preview", r.URL.Query().Get("api-version")) + w.Write(rsl) - url = *r.URL - method = r.Method })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) _, err = client.UpdateRunnerScaleSet(ctx, 1, &runnerScaleSet) require.NoError(t, err) - - u := url.String() - expectedUrl := "/_apis/runtime/runnerscalesets/1?api-version=6.0-preview" - assert.Equal(t, expectedUrl, u) - - assert.Equal(t, "PATCH", method) }) } diff --git a/github/actions/client_runner_test.go b/github/actions/client_runner_test.go index 38d7b298..1ad4947e 100644 --- a/github/actions/client_runner_test.go +++ b/github/actions/client_runner_test.go @@ -29,7 +29,7 @@ func TestGetRunner(t *testing.T) { w.Write(response) })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) got, err := client.GetRunner(ctx, runnerID) @@ -50,7 +50,7 @@ func TestGetRunner(t *testing.T) { actualRetry++ })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth, actions.WithRetryMax(retryMax), actions.WithRetryWaitMax(retryWaitMax)) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth, actions.WithRetryMax(retryMax), actions.WithRetryWaitMax(retryWaitMax)) require.NoError(t, err) _, err = client.GetRunner(ctx, runnerID) @@ -78,7 +78,7 @@ func TestGetRunnerByName(t *testing.T) { w.Write(response) })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) got, err := client.GetRunnerByName(ctx, runnerName) @@ -94,7 +94,7 @@ func TestGetRunnerByName(t *testing.T) { w.Write(response) })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) got, err := client.GetRunnerByName(ctx, runnerName) @@ -116,7 +116,7 @@ func TestGetRunnerByName(t *testing.T) { actualRetry++ })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth, actions.WithRetryMax(retryMax), actions.WithRetryWaitMax(retryWaitMax)) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth, actions.WithRetryMax(retryMax), actions.WithRetryWaitMax(retryWaitMax)) require.NoError(t, err) _, err = client.GetRunnerByName(ctx, runnerName) @@ -138,7 +138,7 @@ func TestDeleteRunner(t *testing.T) { w.WriteHeader(http.StatusNoContent) })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) err = client.RemoveRunner(ctx, runnerID) @@ -160,7 +160,6 @@ func TestDeleteRunner(t *testing.T) { })) client, err := actions.NewClient( - ctx, server.configURLForOrg("my-org"), auth, actions.WithRetryMax(retryMax), @@ -193,7 +192,7 @@ func TestGetRunnerGroupByName(t *testing.T) { w.Write(response) })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) got, err := client.GetRunnerGroupByName(ctx, runnerGroupName) @@ -209,7 +208,7 @@ func TestGetRunnerGroupByName(t *testing.T) { w.Write(response) })) - client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth) + client, err := actions.NewClient(server.configURLForOrg("my-org"), auth) require.NoError(t, err) got, err := client.GetRunnerGroupByName(ctx, runnerGroupName) diff --git a/github/actions/client_tls_test.go b/github/actions/client_tls_test.go index 320798b8..5e7190b5 100644 --- a/github/actions/client_tls_test.go +++ b/github/actions/client_tls_test.go @@ -22,9 +22,9 @@ import ( func TestServerWithSelfSignedCertificates(t *testing.T) { ctx := context.Background() - // this handler is a very very barebones replica of actions api // used during the creation of a a new client + var u string h := func(w http.ResponseWriter, r *http.Request) { // handle get registration token if strings.HasSuffix(r.URL.Path, "/runners/registration-token") { @@ -46,9 +46,12 @@ func TestServerWithSelfSignedCertificates(t *testing.T) { require.NoError(t, err) tokenString, err := token.SignedString(privateKey) require.NoError(t, err) - w.Write([]byte(`{"url":"TODO","token":"` + tokenString + `"}`)) + w.Write([]byte(`{"url":"` + u + `","token":"` + tokenString + `"}`)) return } + + // default happy response for RemoveRunner + w.WriteHeader(http.StatusNoContent) } certPath := filepath.Join("testdata", "server.crt") @@ -56,13 +59,17 @@ func TestServerWithSelfSignedCertificates(t *testing.T) { t.Run("client without ca certs", func(t *testing.T) { server := startNewTLSTestServer(t, certPath, keyPath, http.HandlerFunc(h)) + u = server.URL configURL := server.URL + "/my-org" auth := &actions.ActionsAuth{ Token: "token", } - client, err := actions.NewClient(ctx, configURL, auth) - assert.Nil(t, client) + client, err := actions.NewClient(configURL, auth) + require.NoError(t, err) + require.NotNil(t, client) + + err = client.RemoveRunner(ctx, 1) require.NotNil(t, err) if runtime.GOOS == "linux" { @@ -78,6 +85,7 @@ func TestServerWithSelfSignedCertificates(t *testing.T) { t.Run("client with ca certs", func(t *testing.T) { server := startNewTLSTestServer(t, certPath, keyPath, http.HandlerFunc(h)) + u = server.URL configURL := server.URL + "/my-org" auth := &actions.ActionsAuth{ @@ -90,9 +98,12 @@ func TestServerWithSelfSignedCertificates(t *testing.T) { pool, err := actions.RootCAsFromConfigMap(map[string][]byte{"cert": cert}) require.NoError(t, err) - client, err := actions.NewClient(ctx, configURL, auth, actions.WithRootCAs(pool)) + client, err := actions.NewClient(configURL, auth, actions.WithRootCAs(pool)) require.NoError(t, err) assert.NotNil(t, client) + + err = client.RemoveRunner(ctx, 1) + assert.NoError(t, err) }) t.Run("client with ca chain certs", func(t *testing.T) { @@ -102,6 +113,7 @@ func TestServerWithSelfSignedCertificates(t *testing.T) { filepath.Join("testdata", "leaf.key"), http.HandlerFunc(h), ) + u = server.URL configURL := server.URL + "/my-org" auth := &actions.ActionsAuth{ @@ -114,9 +126,12 @@ func TestServerWithSelfSignedCertificates(t *testing.T) { pool, err := actions.RootCAsFromConfigMap(map[string][]byte{"cert": cert}) require.NoError(t, err) - client, err := actions.NewClient(ctx, configURL, auth, actions.WithRootCAs(pool), actions.WithRetryMax(0)) + client, err := actions.NewClient(configURL, auth, actions.WithRootCAs(pool), actions.WithRetryMax(0)) require.NoError(t, err) - assert.NotNil(t, client) + require.NotNil(t, client) + + err = client.RemoveRunner(ctx, 1) + assert.NoError(t, err) }) t.Run("client skipping tls verification", func(t *testing.T) { @@ -127,7 +142,7 @@ func TestServerWithSelfSignedCertificates(t *testing.T) { Token: "token", } - client, err := actions.NewClient(ctx, configURL, auth, actions.WithoutTLSVerify()) + client, err := actions.NewClient(configURL, auth, actions.WithoutTLSVerify()) require.NoError(t, err) assert.NotNil(t, client) }) diff --git a/github/actions/config.go b/github/actions/config.go new file mode 100644 index 00000000..204fa0a4 --- /dev/null +++ b/github/actions/config.go @@ -0,0 +1,98 @@ +package actions + +import ( + "fmt" + "net/url" + "strings" +) + +var ErrInvalidGitHubConfigURL = fmt.Errorf("invalid config URL, should point to an enterprise, org, or repository") + +type GitHubScope int + +const ( + GitHubScopeUnknown GitHubScope = iota + GitHubScopeEnterprise + GitHubScopeOrganization + GitHubScopeRepository +) + +type GitHubConfig struct { + ConfigURL *url.URL + Scope GitHubScope + + Enterprise string + Organization string + Repository string + + IsHosted bool +} + +func ParseGitHubConfigFromURL(in string) (*GitHubConfig, error) { + u, err := url.Parse(in) + if err != nil { + return nil, err + } + + isHosted := u.Host == "github.com" || + u.Host == "www.github.com" || + u.Host == "github.localhost" + + configURL := &GitHubConfig{ + ConfigURL: u, + IsHosted: isHosted, + } + + invalidURLError := fmt.Errorf("%q: %w", u.String(), ErrInvalidGitHubConfigURL) + + pathParts := strings.Split(strings.TrimPrefix(u.Path, "/"), "/") + + switch len(pathParts) { + case 1: // Organization + if pathParts[0] == "" { + return nil, invalidURLError + } + + configURL.Scope = GitHubScopeOrganization + configURL.Organization = pathParts[0] + + case 2: // Repository or enterprise + if strings.ToLower(pathParts[0]) == "enterprises" { + configURL.Scope = GitHubScopeEnterprise + configURL.Enterprise = pathParts[1] + break + } + + configURL.Scope = GitHubScopeRepository + configURL.Organization = pathParts[0] + configURL.Repository = pathParts[1] + default: + return nil, invalidURLError + } + + return configURL, nil +} + +func (c *GitHubConfig) GitHubAPIURL(path string) *url.URL { + result := &url.URL{ + Scheme: c.ConfigURL.Scheme, + } + + switch c.ConfigURL.Host { + // Hosted + case "github.com", "github.localhost": + result.Host = fmt.Sprintf("api.%s", c.ConfigURL.Host) + // re-routing www.github.com to api.github.com + case "www.github.com": + result.Host = "api.github.com" + + // Enterprise + default: + result.Host = c.ConfigURL.Host + result.Path = "/api/v3" + } + + result.Path += path + + return result +} diff --git a/github/actions/config_test.go b/github/actions/config_test.go new file mode 100644 index 00000000..a9a8368f --- /dev/null +++ b/github/actions/config_test.go @@ -0,0 +1,117 @@ +package actions_test + +import ( + "errors" + "net/url" + "testing" + + "github.com/actions/actions-runner-controller/github/actions" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGitHubConfig(t *testing.T) { + t.Run("when given a valid URL", func(t *testing.T) { + tests := []struct { + configURL string + expected *actions.GitHubConfig + }{ + { + configURL: "https://github.com/org/repo", + expected: &actions.GitHubConfig{ + Scope: actions.GitHubScopeRepository, + Enterprise: "", + Organization: "org", + Repository: "repo", + IsHosted: true, + }, + }, + { + configURL: "https://github.com/org", + expected: &actions.GitHubConfig{ + Scope: actions.GitHubScopeOrganization, + Enterprise: "", + Organization: "org", + Repository: "", + IsHosted: true, + }, + }, + { + configURL: "https://github.com/enterprises/my-enterprise", + expected: &actions.GitHubConfig{ + Scope: actions.GitHubScopeEnterprise, + Enterprise: "my-enterprise", + Organization: "", + Repository: "", + IsHosted: true, + }, + }, + { + configURL: "https://www.github.com/org", + expected: &actions.GitHubConfig{ + Scope: actions.GitHubScopeOrganization, + Enterprise: "", + Organization: "org", + Repository: "", + IsHosted: true, + }, + }, + { + configURL: "https://github.localhost/org", + expected: &actions.GitHubConfig{ + Scope: actions.GitHubScopeOrganization, + Enterprise: "", + Organization: "org", + Repository: "", + IsHosted: true, + }, + }, + { + configURL: "https://my-ghes.com/org", + expected: &actions.GitHubConfig{ + Scope: actions.GitHubScopeOrganization, + Enterprise: "", + Organization: "org", + Repository: "", + IsHosted: false, + }, + }, + } + + for _, test := range tests { + t.Run(test.configURL, func(t *testing.T) { + parsedURL, err := url.Parse(test.configURL) + require.NoError(t, err) + test.expected.ConfigURL = parsedURL + + cfg, err := actions.ParseGitHubConfigFromURL(test.configURL) + require.NoError(t, err) + assert.Equal(t, test.expected, cfg) + }) + } + }) + + t.Run("when given an invalid URL", func(t *testing.T) {}) + invalidURLs := []string{ + "https://github.com/", + "https://github.com", + "https://github.com/some/random/path", + } + + for _, u := range invalidURLs { + _, err := actions.ParseGitHubConfigFromURL(u) + require.Error(t, err) + assert.True(t, errors.Is(err, actions.ErrInvalidGitHubConfigURL)) + } +} + +func TestGitHubConfig_GitHubAPIURL(t *testing.T) { + t.Run("when hosted", func(t *testing.T) { + config, err := actions.ParseGitHubConfigFromURL("https://github.com/org/repo") + require.NoError(t, err) + + result := config.GitHubAPIURL("/some/path") + assert.Equal(t, "https://api.github.com/some/path", result.String()) + }) + t.Run("when not hosted", func(t *testing.T) {}) +} diff --git a/github/actions/github_api_request_test.go b/github/actions/github_api_request_test.go new file mode 100644 index 00000000..3a378149 --- /dev/null +++ b/github/actions/github_api_request_test.go @@ -0,0 +1,171 @@ +package actions_test + +import ( + "context" + "io" + "net/http" + "net/url" + "strings" + "testing" + "time" + + "github.com/actions/actions-runner-controller/github/actions" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewGitHubAPIRequest(t *testing.T) { + ctx := context.Background() + + t.Run("uses the right host/path prefix", func(t *testing.T) { + scenarios := []struct { + configURL string + path string + expected string + }{ + { + configURL: "https://github.com/org/repo", + path: "/app/installations/123/access_tokens", + expected: "https://api.github.com/app/installations/123/access_tokens", + }, + { + configURL: "https://www.github.com/org/repo", + path: "/app/installations/123/access_tokens", + expected: "https://api.github.com/app/installations/123/access_tokens", + }, + { + configURL: "http://github.localhost/org/repo", + path: "/app/installations/123/access_tokens", + expected: "http://api.github.localhost/app/installations/123/access_tokens", + }, + { + configURL: "https://my-instance.com/org/repo", + path: "/app/installations/123/access_tokens", + expected: "https://my-instance.com/api/v3/app/installations/123/access_tokens", + }, + { + configURL: "http://localhost/org/repo", + path: "/app/installations/123/access_tokens", + expected: "http://localhost/api/v3/app/installations/123/access_tokens", + }, + } + + for _, scenario := range scenarios { + client, err := actions.NewClient(scenario.configURL, nil) + require.NoError(t, err) + + req, err := client.NewGitHubAPIRequest(ctx, http.MethodGet, scenario.path, nil) + require.NoError(t, err) + assert.Equal(t, scenario.expected, req.URL.String()) + } + }) + + t.Run("sets user agent header if present", func(t *testing.T) { + client, err := actions.NewClient("http://localhost/my-org", nil, actions.WithUserAgent("my-agent")) + require.NoError(t, err) + + req, err := client.NewGitHubAPIRequest(ctx, http.MethodGet, "/app/installations/123/access_tokens", nil) + require.NoError(t, err) + + assert.Equal(t, "my-agent", req.Header.Get("User-Agent")) + }) + + t.Run("sets the body we pass", func(t *testing.T) { + client, err := actions.NewClient("http://localhost/my-org", nil) + require.NoError(t, err) + + req, err := client.NewGitHubAPIRequest( + ctx, + http.MethodGet, + "/app/installations/123/access_tokens", + strings.NewReader("the-body"), + ) + require.NoError(t, err) + + b, err := io.ReadAll(req.Body) + require.NoError(t, err) + assert.Equal(t, "the-body", string(b)) + }) +} + +func TestNewActionsServiceRequest(t *testing.T) { + ctx := context.Background() + defaultCreds := &actions.ActionsAuth{Token: "token"} + + t.Run("manages authentication", func(t *testing.T) { + t.Run("client is brand new", func(t *testing.T) { + token := defaultActionsToken(t) + server := newActionsServer(t, nil, withActionsToken(token)) + + client, err := actions.NewClient(server.configURLForOrg("my-org"), defaultCreds) + require.NoError(t, err) + + req, err := client.NewActionsServiceRequest(ctx, http.MethodGet, "my-path", nil) + require.NoError(t, err) + + assert.Equal(t, "Bearer "+token, req.Header.Get("Authorization")) + }) + + t.Run("admin token is about to expire", func(t *testing.T) { + newToken := defaultActionsToken(t) + server := newActionsServer(t, nil, withActionsToken(newToken)) + + client, err := actions.NewClient(server.configURLForOrg("my-org"), defaultCreds) + require.NoError(t, err) + client.ActionsServiceAdminToken = "expiring-token" + client.ActionsServiceAdminTokenExpiresAt = time.Now().Add(59 * time.Second) + + req, err := client.NewActionsServiceRequest(ctx, http.MethodGet, "my-path", nil) + require.NoError(t, err) + + assert.Equal(t, "Bearer "+newToken, req.Header.Get("Authorization")) + }) + + t.Run("token is currently valid", func(t *testing.T) { + tokenThatShouldNotBeFetched := defaultActionsToken(t) + server := newActionsServer(t, nil, withActionsToken(tokenThatShouldNotBeFetched)) + + client, err := actions.NewClient(server.configURLForOrg("my-org"), defaultCreds) + require.NoError(t, err) + client.ActionsServiceAdminToken = "healthy-token" + client.ActionsServiceAdminTokenExpiresAt = time.Now().Add(1 * time.Hour) + + req, err := client.NewActionsServiceRequest(ctx, http.MethodGet, "my-path", nil) + require.NoError(t, err) + + assert.Equal(t, "Bearer healthy-token", req.Header.Get("Authorization")) + }) + }) + + t.Run("builds the right URL including api version", func(t *testing.T) { + server := newActionsServer(t, nil) + + client, err := actions.NewClient(server.configURLForOrg("my-org"), defaultCreds) + require.NoError(t, err) + + req, err := client.NewActionsServiceRequest(ctx, http.MethodGet, "/my/path?name=banana", nil) + require.NoError(t, err) + + serverURL, err := url.Parse(server.URL) + require.NoError(t, err) + + result := req.URL + assert.Equal(t, serverURL.Host, result.Host) + assert.Equal(t, "/tenant/123/my/path", result.Path) + assert.Equal(t, "banana", result.Query().Get("name")) + assert.Equal(t, "6.0-preview", result.Query().Get("api-version")) + }) + + t.Run("populates header", func(t *testing.T) { + server := newActionsServer(t, nil) + + client, err := actions.NewClient(server.configURLForOrg("my-org"), defaultCreds, actions.WithUserAgent("my-agent")) + require.NoError(t, err) + + req, err := client.NewActionsServiceRequest(ctx, http.MethodGet, "/my/path", nil) + require.NoError(t, err) + + assert.Equal(t, "my-agent", req.Header.Get("User-Agent")) + assert.Equal(t, "application/json", req.Header.Get("Content-Type")) + }) +} diff --git a/github/actions/multi_client.go b/github/actions/multi_client.go index 85e0fa75..b875c872 100644 --- a/github/actions/multi_client.go +++ b/github/actions/multi_client.go @@ -106,7 +106,6 @@ func (m *multiClient) GetClientFor(ctx context.Context, githubConfigURL string, m.logger.Info("creating new client", "githubConfigURL", githubConfigURL, "namespace", namespace) client, err := NewClient( - ctx, githubConfigURL, &creds, WithUserAgent(m.userAgent), diff --git a/github/actions/url_test.go b/github/actions/url_test.go deleted file mode 100644 index ae296a30..00000000 --- a/github/actions/url_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package actions - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestGithubAPIURL(t *testing.T) { - tests := []struct { - configURL string - path string - expected string - }{ - { - configURL: "https://github.com/org/repo", - path: "/app/installations/123/access_tokens", - expected: "https://api.github.com/app/installations/123/access_tokens", - }, - { - configURL: "https://www.github.com/org/repo", - path: "/app/installations/123/access_tokens", - expected: "https://api.github.com/app/installations/123/access_tokens", - }, - { - configURL: "http://github.localhost/org/repo", - path: "/app/installations/123/access_tokens", - expected: "http://api.github.localhost/app/installations/123/access_tokens", - }, - { - configURL: "https://my-instance.com/org/repo", - path: "/app/installations/123/access_tokens", - expected: "https://my-instance.com/api/v3/app/installations/123/access_tokens", - }, - { - configURL: "http://localhost/org/repo", - path: "/app/installations/123/access_tokens", - expected: "http://localhost/api/v3/app/installations/123/access_tokens", - }, - } - - for _, test := range tests { - actual, err := githubAPIURL(test.configURL, test.path) - require.NoError(t, err) - assert.Equal(t, test.expected, actual) - } -}