diff --git a/github/actions/client.go b/github/actions/client.go index fe164330..fd99dd21 100644 --- a/github/actions/client.go +++ b/github/actions/client.go @@ -842,23 +842,17 @@ type accessToken struct { } func (c *Client) fetchAccessToken(ctx context.Context, gitHubConfigURL string, creds *GitHubAppAuth) (*accessToken, error) { - parsedGitHubConfigURL, err := url.Parse(gitHubConfigURL) - if err != nil { - return nil, err - } - accessTokenJWT, err := createJWTForGitHubApp(creds) if err != nil { return nil, err } - ru := fmt.Sprintf("%v://%v/app/installations/%v/access_tokens", parsedGitHubConfigURL.Scheme, parsedGitHubConfigURL.Host, creds.AppInstallationID) - accessTokenURL, err := url.Parse(ru) + u, err := githubAPIURL(gitHubConfigURL, fmt.Sprintf("/app/installations/%v/access_tokens", creds.AppInstallationID)) if err != nil { return nil, err } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, accessTokenURL.String(), nil) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, u, nil) if err != nil { return nil, err } @@ -867,7 +861,7 @@ func (c *Client) fetchAccessToken(ctx context.Context, gitHubConfigURL string, c req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessTokenJWT)) req.Header.Add("User-Agent", c.userAgent) - c.logger.Info("getting access token for GitHub App auth", "accessTokenURL", accessTokenURL.String()) + c.logger.Info("getting access token for GitHub App auth", "accessTokenURL", u) resp, err := c.Do(req) if err != nil { @@ -1090,3 +1084,32 @@ func (c *Client) refreshTokenIfNeeded(ctx context.Context) error { return nil } + +func githubAPIURL(configURL, path string) (string, error) { + u, err := url.Parse(configURL) + if err != nil { + return "", err + } + + result := &url.URL{ + Scheme: u.Scheme, + } + + switch u.Host { + // Hosted + case "github.com", "github.localhost": + result.Host = fmt.Sprintf("api.%s", u.Host) + // re-routing www.github.com to api.github.com + case "www.github.com": + result.Host = "api.github.com" + + // Enterprise + default: + result.Host = u.Host + result.Path = "/api/v3" + } + + result.Path += path + + return result.String(), nil +} diff --git a/github/actions/url_test.go b/github/actions/url_test.go new file mode 100644 index 00000000..ae296a30 --- /dev/null +++ b/github/actions/url_test.go @@ -0,0 +1,48 @@ +package actions + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGithubAPIURL(t *testing.T) { + tests := []struct { + configURL string + path string + expected string + }{ + { + configURL: "https://github.com/org/repo", + path: "/app/installations/123/access_tokens", + expected: "https://api.github.com/app/installations/123/access_tokens", + }, + { + configURL: "https://www.github.com/org/repo", + path: "/app/installations/123/access_tokens", + expected: "https://api.github.com/app/installations/123/access_tokens", + }, + { + configURL: "http://github.localhost/org/repo", + path: "/app/installations/123/access_tokens", + expected: "http://api.github.localhost/app/installations/123/access_tokens", + }, + { + configURL: "https://my-instance.com/org/repo", + path: "/app/installations/123/access_tokens", + expected: "https://my-instance.com/api/v3/app/installations/123/access_tokens", + }, + { + configURL: "http://localhost/org/repo", + path: "/app/installations/123/access_tokens", + expected: "http://localhost/api/v3/app/installations/123/access_tokens", + }, + } + + for _, test := range tests { + actual, err := githubAPIURL(test.configURL, test.path) + require.NoError(t, err) + assert.Equal(t, test.expected, actual) + } +}