Remove network requests from actions.NewClient (#2219)

Co-authored-by: Nikola Jokic <jokicnikola07@gmail.com>
This commit is contained in:
Francesco Renzi 2023-01-31 10:55:23 +00:00 committed by GitHub
parent cc26593a9b
commit df12e00c9e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 794 additions and 517 deletions

View File

@ -85,7 +85,6 @@ func run(rc RunnerScaleSetListenerConfig, logger logr.Logger) error {
}
actionsServiceClient, err := actions.NewClient(
ctx,
rc.ConfigureUrl,
creds,
actions.WithUserAgent(fmt.Sprintf("actions-runner-controller/%s", build.Version)),

View File

@ -17,9 +17,20 @@ import (
// /actions/runner-registration endpoints will be handled by the provided
// handler. The returned server is started and will be automatically closed
// when the test ends.
func newActionsServer(t *testing.T, handler http.Handler) *actionsServer {
var u string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
func newActionsServer(t *testing.T, handler http.Handler, options ...actionsServerOption) *actionsServer {
s := httptest.NewServer(nil)
server := &actionsServer{
Server: s,
}
t.Cleanup(func() {
server.Close()
})
for _, option := range options {
option(server)
}
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// handle getRunnerRegistrationToken
if strings.HasSuffix(r.URL.Path, "/runners/registration-token") {
w.WriteHeader(http.StatusCreated)
@ -29,41 +40,55 @@ func newActionsServer(t *testing.T, handler http.Handler) *actionsServer {
// handle getActionsServiceAdminConnection
if strings.HasSuffix(r.URL.Path, "/actions/runner-registration") {
claims := &jwt.RegisteredClaims{
IssuedAt: jwt.NewNumericDate(time.Now().Add(-1 * time.Minute)),
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Minute)),
Issuer: "123",
if server.token == "" {
server.token = defaultActionsToken(t)
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
privateKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(samplePrivateKey))
require.NoError(t, err)
tokenString, err := token.SignedString(privateKey)
require.NoError(t, err)
w.Write([]byte(`{"url":"` + u + `","token":"` + tokenString + `"}`))
w.Write([]byte(`{"url":"` + s.URL + `/tenant/123/","token":"` + server.token + `"}`))
return
}
handler.ServeHTTP(w, r)
}))
u = server.URL
t.Cleanup(func() {
server.Close()
})
return &actionsServer{server}
server.Config.Handler = h
return server
}
type actionsServerOption func(*actionsServer)
func withActionsToken(token string) actionsServerOption {
return func(s *actionsServer) {
s.token = token
}
}
type actionsServer struct {
*httptest.Server
token string
}
func (s *actionsServer) configURLForOrg(org string) string {
return s.URL + "/" + org
}
func defaultActionsToken(t *testing.T) string {
claims := &jwt.RegisteredClaims{
IssuedAt: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute)),
ExpiresAt: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)),
Issuer: "123",
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
privateKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(samplePrivateKey))
require.NoError(t, err)
tokenString, err := token.SignedString(privateKey)
require.NoError(t, err)
return tokenString
}
const samplePrivateKey = `-----BEGIN RSA PRIVATE KEY-----
MIICWgIBAAKBgHXfRT9cv9UY9fAAD4+1RshpfSSZe277urfEmPfX3/Og9zJYRk//
CZrJVD1CaBZDiIyQsNEzjta7r4UsqWdFOggiNN2E7ZTFQjMSaFkVgrzHqWuiaCBf

View File

@ -0,0 +1,61 @@
package actions_test
import (
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/actions/actions-runner-controller/github/actions"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestClient_Do(t *testing.T) {
t.Run("trims byte order mark from response if present", func(t *testing.T) {
t.Run("when there is no body", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
}))
defer server.Close()
client, err := actions.NewClient("https://localhost/org/repo", &actions.ActionsAuth{Token: "token"})
require.NoError(t, err)
req, err := http.NewRequest("GET", server.URL, nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Empty(t, string(body))
})
responses := []string{
"\xef\xbb\xbf{\"foo\":\"bar\"}",
"{\"foo\":\"bar\"}",
}
for _, response := range responses {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(response))
}))
defer server.Close()
client, err := actions.NewClient("https://localhost/org/repo", &actions.ActionsAuth{Token: "token"})
require.NoError(t, err)
req, err := http.NewRequest("GET", server.URL, nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, "{\"foo\":\"bar\"}", string(body))
}
})
}

View File

@ -12,9 +12,7 @@ import (
"log"
"net/http"
"net/url"
"path"
"strconv"
"strings"
"sync"
"time"
@ -62,17 +60,17 @@ type Client struct {
mu sync.Mutex
// TODO: Convert to unexported fields once refactor of Listener is complete
ActionsServiceAdminToken *string
ActionsServiceAdminTokenExpiresAt *time.Time
ActionsServiceURL *string
ActionsServiceAdminToken string
ActionsServiceAdminTokenExpiresAt time.Time
ActionsServiceURL string
retryMax int
retryWaitMax time.Duration
creds *ActionsAuth
githubConfigURL string
logger logr.Logger
userAgent string
creds *ActionsAuth
config *GitHubConfig
logger logr.Logger
userAgent string
rootCAs *x509.CertPool
tlsInsecureSkipVerify bool
@ -116,11 +114,16 @@ func WithoutTLSVerify() ClientOption {
}
}
func NewClient(ctx context.Context, githubConfigURL string, creds *ActionsAuth, options ...ClientOption) (ActionsService, error) {
func NewClient(githubConfigURL string, creds *ActionsAuth, options ...ClientOption) (*Client, error) {
config, err := ParseGitHubConfigFromURL(githubConfigURL)
if err != nil {
return nil, fmt.Errorf("failed to parse githubConfigURL: %w", err)
}
ac := &Client{
creds: creds,
githubConfigURL: githubConfigURL,
logger: logr.Discard(),
creds: creds,
config: config,
logger: logr.Discard(),
// retryablehttp defaults
retryMax: 4,
@ -132,9 +135,6 @@ func NewClient(ctx context.Context, githubConfigURL string, creds *ActionsAuth,
}
retryClient := retryablehttp.NewClient()
// TODO: this silences retryclient default logger, do we want to provide one
// instead? by default retryablehttp logs all requests to stderr
retryClient.Logger = log.New(io.Discard, "", log.LstdFlags)
retryClient.RetryMax = ac.retryMax
@ -161,48 +161,94 @@ func NewClient(ctx context.Context, githubConfigURL string, creds *ActionsAuth,
retryClient.HTTPClient.Transport = transport
ac.Client = retryClient.StandardClient()
rt, err := ac.getRunnerRegistrationToken(ctx, githubConfigURL, *creds)
if err != nil {
return nil, fmt.Errorf("failed to get runner registration token: %w", err)
}
adminConnInfo, err := ac.getActionsServiceAdminConnection(ctx, rt, githubConfigURL)
if err != nil {
return nil, fmt.Errorf("failed to get actions service admin connection: %w", err)
}
ac.ActionsServiceURL = adminConnInfo.ActionsServiceUrl
ac.mu.Lock()
defer ac.mu.Unlock()
ac.ActionsServiceAdminToken = adminConnInfo.AdminToken
ac.ActionsServiceAdminTokenExpiresAt, err = actionsServiceAdminTokenExpiresAt(*adminConnInfo.AdminToken)
if err != nil {
return nil, fmt.Errorf("failed to get admin token expire at: %w", err)
}
return ac, nil
}
func (c *Client) GetRunnerScaleSet(ctx context.Context, runnerScaleSetName string) (*RunnerScaleSet, error) {
u := fmt.Sprintf("%s/%s?name=%s&api-version=6.0-preview", *c.ActionsServiceURL, scaleSetEndpoint, runnerScaleSetName)
if err := c.refreshTokenIfNeeded(ctx); err != nil {
return nil, fmt.Errorf("failed to refresh admin token if needed: %w", err)
func (c *Client) Do(req *http.Request) (*http.Response, error) {
resp, err := c.Client.Do(req)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
err = resp.Body.Close()
if err != nil {
return nil, err
}
body = trimByteOrderMark(body)
resp.Body = io.NopCloser(bytes.NewReader(body))
return resp, nil
}
func (c *Client) NewGitHubAPIRequest(ctx context.Context, method, path string, body io.Reader) (*http.Request, error) {
u := c.config.GitHubAPIURL(path)
req, err := http.NewRequestWithContext(ctx, method, u.String(), body)
if err != nil {
return nil, err
}
if c.userAgent != "" {
req.Header.Set("User-Agent", c.userAgent)
}
return req, nil
}
func (c *Client) NewActionsServiceRequest(ctx context.Context, method, path string, body io.Reader) (*http.Request, error) {
err := c.updateTokenIfNeeded(ctx)
if err != nil {
return nil, err
}
parsedPath, err := url.Parse(path)
if err != nil {
return nil, err
}
urlString, err := url.JoinPath(c.ActionsServiceURL, parsedPath.Path)
if err != nil {
return nil, err
}
u, err := url.Parse(urlString)
if err != nil {
return nil, err
}
q := u.Query()
for k, v := range parsedPath.Query() {
q[k] = v
}
if q.Get("api-version") == "" {
q.Set("api-version", "6.0-preview")
}
u.RawQuery = q.Encode()
req, err := http.NewRequestWithContext(ctx, method, u.String(), body)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *c.ActionsServiceAdminToken))
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.ActionsServiceAdminToken))
if c.userAgent != "" {
req.Header.Set("User-Agent", c.userAgent)
}
return req, nil
}
func (c *Client) GetRunnerScaleSet(ctx context.Context, runnerScaleSetName string) (*RunnerScaleSet, error) {
path := fmt.Sprintf("/%s?name=%s", scaleSetEndpoint, runnerScaleSetName)
req, err := c.NewActionsServiceRequest(ctx, http.MethodGet, path, nil)
if err != nil {
return nil, err
}
resp, err := c.Do(req)
if err != nil {
return nil, err
@ -211,8 +257,9 @@ func (c *Client) GetRunnerScaleSet(ctx context.Context, runnerScaleSetName strin
if resp.StatusCode != http.StatusOK {
return nil, ParseActionsErrorFromResponse(resp)
}
var runnerScaleSetList *runnerScaleSetsResponse
err = unmarshalBody(resp, &runnerScaleSetList)
err = json.NewDecoder(resp.Body).Decode(&runnerScaleSetList)
if err != nil {
return nil, err
}
@ -227,24 +274,12 @@ func (c *Client) GetRunnerScaleSet(ctx context.Context, runnerScaleSetName strin
}
func (c *Client) GetRunnerScaleSetById(ctx context.Context, runnerScaleSetId int) (*RunnerScaleSet, error) {
u := fmt.Sprintf("%s/%s/%d?api-version=6.0-preview", *c.ActionsServiceURL, scaleSetEndpoint, runnerScaleSetId)
if err := c.refreshTokenIfNeeded(ctx); err != nil {
return nil, fmt.Errorf("failed to refresh admin token if needed: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
path := fmt.Sprintf("/%s/%d", scaleSetEndpoint, runnerScaleSetId)
req, err := c.NewActionsServiceRequest(ctx, http.MethodGet, path, nil)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *c.ActionsServiceAdminToken))
if c.userAgent != "" {
req.Header.Set("User-Agent", c.userAgent)
}
resp, err := c.Do(req)
if err != nil {
return nil, err
@ -255,7 +290,7 @@ func (c *Client) GetRunnerScaleSetById(ctx context.Context, runnerScaleSetId int
}
var runnerScaleSet *RunnerScaleSet
err = unmarshalBody(resp, &runnerScaleSet)
err = json.NewDecoder(resp.Body).Decode(&runnerScaleSet)
if err != nil {
return nil, err
}
@ -263,24 +298,12 @@ func (c *Client) GetRunnerScaleSetById(ctx context.Context, runnerScaleSetId int
}
func (c *Client) GetRunnerGroupByName(ctx context.Context, runnerGroup string) (*RunnerGroup, error) {
u := fmt.Sprintf("%s/_apis/runtime/runnergroups/?groupName=%s&api-version=6.0-preview", *c.ActionsServiceURL, runnerGroup)
if err := c.refreshTokenIfNeeded(ctx); err != nil {
return nil, fmt.Errorf("failed to refresh admin token if needed: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
path := fmt.Sprintf("/_apis/runtime/runnergroups/?groupName=%s", runnerGroup)
req, err := c.NewActionsServiceRequest(ctx, http.MethodGet, path, nil)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *c.ActionsServiceAdminToken))
if c.userAgent != "" {
req.Header.Set("User-Agent", c.userAgent)
}
resp, err := c.Do(req)
if err != nil {
return nil, err
@ -295,7 +318,7 @@ func (c *Client) GetRunnerGroupByName(ctx context.Context, runnerGroup string) (
}
var runnerGroupList *RunnerGroupList
err = unmarshalBody(resp, &runnerGroupList)
err = json.NewDecoder(resp.Body).Decode(&runnerGroupList)
if err != nil {
return nil, err
}
@ -312,29 +335,16 @@ func (c *Client) GetRunnerGroupByName(ctx context.Context, runnerGroup string) (
}
func (c *Client) CreateRunnerScaleSet(ctx context.Context, runnerScaleSet *RunnerScaleSet) (*RunnerScaleSet, error) {
u := fmt.Sprintf("%s/%s?api-version=6.0-preview", *c.ActionsServiceURL, scaleSetEndpoint)
if err := c.refreshTokenIfNeeded(ctx); err != nil {
return nil, fmt.Errorf("failed to refresh admin token if needed: %w", err)
}
body, err := json.Marshal(runnerScaleSet)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, u, bytes.NewBuffer(body))
req, err := c.NewActionsServiceRequest(ctx, http.MethodPost, scaleSetEndpoint, bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *c.ActionsServiceAdminToken))
if c.userAgent != "" {
req.Header.Set("User-Agent", c.userAgent)
}
resp, err := c.Do(req)
if err != nil {
return nil, err
@ -344,7 +354,7 @@ func (c *Client) CreateRunnerScaleSet(ctx context.Context, runnerScaleSet *Runne
return nil, ParseActionsErrorFromResponse(resp)
}
var createdRunnerScaleSet *RunnerScaleSet
err = unmarshalBody(resp, &createdRunnerScaleSet)
err = json.NewDecoder(resp.Body).Decode(&createdRunnerScaleSet)
if err != nil {
return nil, err
}
@ -352,29 +362,18 @@ func (c *Client) CreateRunnerScaleSet(ctx context.Context, runnerScaleSet *Runne
}
func (c *Client) UpdateRunnerScaleSet(ctx context.Context, runnerScaleSetId int, runnerScaleSet *RunnerScaleSet) (*RunnerScaleSet, error) {
u := fmt.Sprintf("%s/%s/%d?api-version=6.0-preview", *c.ActionsServiceURL, scaleSetEndpoint, runnerScaleSetId)
if err := c.refreshTokenIfNeeded(ctx); err != nil {
return nil, fmt.Errorf("failed to refresh admin token if needed: %w", err)
}
path := fmt.Sprintf("%s/%d", scaleSetEndpoint, runnerScaleSetId)
body, err := json.Marshal(runnerScaleSet)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPatch, u, bytes.NewBuffer(body))
req, err := c.NewActionsServiceRequest(ctx, http.MethodPatch, path, bytes.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *c.ActionsServiceAdminToken))
if c.userAgent != "" {
req.Header.Set("User-Agent", c.userAgent)
}
resp, err := c.Do(req)
if err != nil {
return nil, err
@ -385,7 +384,7 @@ func (c *Client) UpdateRunnerScaleSet(ctx context.Context, runnerScaleSetId int,
}
var updatedRunnerScaleSet *RunnerScaleSet
err = unmarshalBody(resp, &updatedRunnerScaleSet)
err = json.NewDecoder(resp.Body).Decode(&updatedRunnerScaleSet)
if err != nil {
return nil, err
}
@ -393,24 +392,12 @@ func (c *Client) UpdateRunnerScaleSet(ctx context.Context, runnerScaleSetId int,
}
func (c *Client) DeleteRunnerScaleSet(ctx context.Context, runnerScaleSetId int) error {
u := fmt.Sprintf("%s/%s/%d?api-version=6.0-preview", *c.ActionsServiceURL, scaleSetEndpoint, runnerScaleSetId)
if err := c.refreshTokenIfNeeded(ctx); err != nil {
return fmt.Errorf("failed to refresh admin token if needed: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, u, nil)
path := fmt.Sprintf("/%s/%d", scaleSetEndpoint, runnerScaleSetId)
req, err := c.NewActionsServiceRequest(ctx, http.MethodDelete, path, nil)
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *c.ActionsServiceAdminToken))
if c.userAgent != "" {
req.Header.Set("User-Agent", c.userAgent)
}
resp, err := c.Do(req)
if err != nil {
return err
@ -425,12 +412,18 @@ func (c *Client) DeleteRunnerScaleSet(ctx context.Context, runnerScaleSetId int)
}
func (c *Client) GetMessage(ctx context.Context, messageQueueUrl, messageQueueAccessToken string, lastMessageId int64) (*RunnerScaleSetMessage, error) {
u := messageQueueUrl
if lastMessageId > 0 {
u = fmt.Sprintf("%s&lassMessageId=%d", messageQueueUrl, lastMessageId)
u, err := url.Parse(messageQueueUrl)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
if lastMessageId > 0 {
q := u.Query()
q.Set("lastMessageId", strconv.FormatInt(lastMessageId, 10))
u.RawQuery = q.Encode()
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
if err != nil {
return nil, err
}
@ -466,7 +459,7 @@ func (c *Client) GetMessage(ctx context.Context, messageQueueUrl, messageQueueAc
}
var message *RunnerScaleSetMessage
err = unmarshalBody(resp, &message)
err = json.NewDecoder(resp.Body).Decode(&message)
if err != nil {
return nil, err
}
@ -514,7 +507,7 @@ func (c *Client) DeleteMessage(ctx context.Context, messageQueueUrl, messageQueu
}
func (c *Client) CreateMessageSession(ctx context.Context, runnerScaleSetId int, owner string) (*RunnerScaleSetSession, error) {
u := fmt.Sprintf("%v/%v/%v/sessions?%v", *c.ActionsServiceURL, scaleSetEndpoint, runnerScaleSetId, apiVersionQueryParam)
path := fmt.Sprintf("/%s/%d/sessions", scaleSetEndpoint, runnerScaleSetId)
newSession := &RunnerScaleSetSession{
OwnerName: owner,
@ -527,49 +520,36 @@ func (c *Client) CreateMessageSession(ctx context.Context, runnerScaleSetId int,
createdSession := &RunnerScaleSetSession{}
err = c.doSessionRequest(ctx, http.MethodPost, u, bytes.NewBuffer(requestData), http.StatusOK, createdSession)
err = c.doSessionRequest(ctx, http.MethodPost, path, bytes.NewBuffer(requestData), http.StatusOK, createdSession)
return createdSession, err
}
func (c *Client) DeleteMessageSession(ctx context.Context, runnerScaleSetId int, sessionId *uuid.UUID) error {
u := fmt.Sprintf("%v/%v/%v/sessions/%v?%v", *c.ActionsServiceURL, scaleSetEndpoint, runnerScaleSetId, sessionId.String(), apiVersionQueryParam)
return c.doSessionRequest(ctx, http.MethodDelete, u, nil, http.StatusNoContent, nil)
path := fmt.Sprintf("/%s/%d/sessions/%s", scaleSetEndpoint, runnerScaleSetId, sessionId.String())
return c.doSessionRequest(ctx, http.MethodDelete, path, nil, http.StatusNoContent, nil)
}
func (c *Client) RefreshMessageSession(ctx context.Context, runnerScaleSetId int, sessionId *uuid.UUID) (*RunnerScaleSetSession, error) {
u := fmt.Sprintf("%v/%v/%v/sessions/%v?%v", *c.ActionsServiceURL, scaleSetEndpoint, runnerScaleSetId, sessionId.String(), apiVersionQueryParam)
path := fmt.Sprintf("/%s/%d/sessions/%s", scaleSetEndpoint, runnerScaleSetId, sessionId.String())
refreshedSession := &RunnerScaleSetSession{}
err := c.doSessionRequest(ctx, http.MethodPatch, u, nil, http.StatusOK, refreshedSession)
err := c.doSessionRequest(ctx, http.MethodPatch, path, nil, http.StatusOK, refreshedSession)
return refreshedSession, err
}
func (c *Client) doSessionRequest(ctx context.Context, method, url string, requestData io.Reader, expectedResponseStatusCode int, responseUnmarshalTarget any) error {
if err := c.refreshTokenIfNeeded(ctx); err != nil {
return fmt.Errorf("failed to refresh admin token if needed: %w", err)
}
req, err := http.NewRequestWithContext(ctx, method, url, requestData)
func (c *Client) doSessionRequest(ctx context.Context, method, path string, requestData io.Reader, expectedResponseStatusCode int, responseUnmarshalTarget any) error {
req, err := c.NewActionsServiceRequest(ctx, method, path, requestData)
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *c.ActionsServiceAdminToken))
if c.userAgent != "" {
req.Header.Set("User-Agent", c.userAgent)
}
resp, err := c.Do(req)
if err != nil {
return err
}
if resp.StatusCode == expectedResponseStatusCode && responseUnmarshalTarget != nil {
err = unmarshalBody(resp, &responseUnmarshalTarget)
return err
return json.NewDecoder(resp.Body).Decode(responseUnmarshalTarget)
}
if resp.StatusCode >= 400 && resp.StatusCode < 500 {
@ -587,7 +567,7 @@ func (c *Client) doSessionRequest(ctx context.Context, method, url string, reque
}
func (c *Client) AcquireJobs(ctx context.Context, runnerScaleSetId int, messageQueueAccessToken string, requestIds []int64) ([]int64, error) {
u := fmt.Sprintf("%s/%s/%d/acquirejobs?api-version=6.0-preview", *c.ActionsServiceURL, scaleSetEndpoint, runnerScaleSetId)
u := fmt.Sprintf("%s/%s/%d/acquirejobs?api-version=6.0-preview", c.ActionsServiceURL, scaleSetEndpoint, runnerScaleSetId)
body, err := json.Marshal(requestIds)
if err != nil {
@ -614,8 +594,8 @@ func (c *Client) AcquireJobs(ctx context.Context, runnerScaleSetId int, messageQ
return nil, ParseActionsErrorFromResponse(resp)
}
var acquiredJobs Int64List
err = unmarshalBody(resp, &acquiredJobs)
var acquiredJobs *Int64List
err = json.NewDecoder(resp.Body).Decode(&acquiredJobs)
if err != nil {
return nil, err
}
@ -624,24 +604,13 @@ func (c *Client) AcquireJobs(ctx context.Context, runnerScaleSetId int, messageQ
}
func (c *Client) GetAcquirableJobs(ctx context.Context, runnerScaleSetId int) (*AcquirableJobList, error) {
u := fmt.Sprintf("%s/%s/%d/acquirablejobs?api-version=6.0-preview", *c.ActionsServiceURL, scaleSetEndpoint, runnerScaleSetId)
path := fmt.Sprintf("/%s/%d/acquirablejobs", scaleSetEndpoint, runnerScaleSetId)
if err := c.refreshTokenIfNeeded(ctx); err != nil {
return nil, fmt.Errorf("failed to refresh admin token if needed: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
req, err := c.NewActionsServiceRequest(ctx, http.MethodGet, path, nil)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *c.ActionsServiceAdminToken))
if c.userAgent != "" {
req.Header.Set("User-Agent", c.userAgent)
}
resp, err := c.Do(req)
if err != nil {
return nil, err
@ -657,7 +626,7 @@ func (c *Client) GetAcquirableJobs(ctx context.Context, runnerScaleSetId int) (*
}
var acquirableJobList *AcquirableJobList
err = unmarshalBody(resp, &acquirableJobList)
err = json.NewDecoder(resp.Body).Decode(&acquirableJobList)
if err != nil {
return nil, err
}
@ -666,28 +635,18 @@ func (c *Client) GetAcquirableJobs(ctx context.Context, runnerScaleSetId int) (*
}
func (c *Client) GenerateJitRunnerConfig(ctx context.Context, jitRunnerSetting *RunnerScaleSetJitRunnerSetting, scaleSetId int) (*RunnerScaleSetJitRunnerConfig, error) {
runnerJitConfigUrl := fmt.Sprintf("%s/%s/%d/generatejitconfig?api-version=6.0-preview", *c.ActionsServiceURL, scaleSetEndpoint, scaleSetId)
if err := c.refreshTokenIfNeeded(ctx); err != nil {
return nil, fmt.Errorf("failed to refresh admin token if needed: %w", err)
}
path := fmt.Sprintf("/%s/%d/generatejitconfig", scaleSetEndpoint, scaleSetId)
body, err := json.Marshal(jitRunnerSetting)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, runnerJitConfigUrl, bytes.NewBuffer(body))
req, err := c.NewActionsServiceRequest(ctx, http.MethodPost, path, bytes.NewBuffer(body))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *c.ActionsServiceAdminToken))
if c.userAgent != "" {
req.Header.Set("User-Agent", c.userAgent)
}
resp, err := c.Do(req)
if err != nil {
return nil, err
@ -698,7 +657,7 @@ func (c *Client) GenerateJitRunnerConfig(ctx context.Context, jitRunnerSetting *
}
var runnerJitConfig *RunnerScaleSetJitRunnerConfig
err = unmarshalBody(resp, &runnerJitConfig)
err = json.NewDecoder(resp.Body).Decode(&runnerJitConfig)
if err != nil {
return nil, err
}
@ -706,24 +665,13 @@ func (c *Client) GenerateJitRunnerConfig(ctx context.Context, jitRunnerSetting *
}
func (c *Client) GetRunner(ctx context.Context, runnerId int64) (*RunnerReference, error) {
url := fmt.Sprintf("%v/%v/%v?%v", *c.ActionsServiceURL, runnerEndpoint, runnerId, apiVersionQueryParam)
path := fmt.Sprintf("/%s/%d", runnerEndpoint, runnerId)
if err := c.refreshTokenIfNeeded(ctx); err != nil {
return nil, fmt.Errorf("failed to refresh admin token if needed: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
req, err := c.NewActionsServiceRequest(ctx, http.MethodGet, path, nil)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *c.ActionsServiceAdminToken))
if c.userAgent != "" {
req.Header.Set("User-Agent", c.userAgent)
}
resp, err := c.Do(req)
if err != nil {
return nil, err
@ -734,7 +682,8 @@ func (c *Client) GetRunner(ctx context.Context, runnerId int64) (*RunnerReferenc
}
var runnerReference *RunnerReference
if err := unmarshalBody(resp, &runnerReference); err != nil {
err = json.NewDecoder(resp.Body).Decode(&runnerReference)
if err != nil {
return nil, err
}
@ -742,24 +691,13 @@ func (c *Client) GetRunner(ctx context.Context, runnerId int64) (*RunnerReferenc
}
func (c *Client) GetRunnerByName(ctx context.Context, runnerName string) (*RunnerReference, error) {
url := fmt.Sprintf("%v/%v?agentName=%v&%v", *c.ActionsServiceURL, runnerEndpoint, runnerName, apiVersionQueryParam)
path := fmt.Sprintf("/%s?agentName=%s", runnerEndpoint, runnerName)
if err := c.refreshTokenIfNeeded(ctx); err != nil {
return nil, fmt.Errorf("failed to refresh admin token if needed: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
req, err := c.NewActionsServiceRequest(ctx, http.MethodGet, path, nil)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *c.ActionsServiceAdminToken))
if c.userAgent != "" {
req.Header.Set("User-Agent", c.userAgent)
}
resp, err := c.Do(req)
if err != nil {
return nil, err
@ -770,7 +708,7 @@ func (c *Client) GetRunnerByName(ctx context.Context, runnerName string) (*Runne
}
var runnerList *RunnerReferenceList
err = unmarshalBody(resp, &runnerList)
err = json.NewDecoder(resp.Body).Decode(&runnerList)
if err != nil {
return nil, err
}
@ -787,24 +725,13 @@ func (c *Client) GetRunnerByName(ctx context.Context, runnerName string) (*Runne
}
func (c *Client) RemoveRunner(ctx context.Context, runnerId int64) error {
url := fmt.Sprintf("%v/%v/%v?%v", *c.ActionsServiceURL, runnerEndpoint, runnerId, apiVersionQueryParam)
path := fmt.Sprintf("/%s/%d", runnerEndpoint, runnerId)
if err := c.refreshTokenIfNeeded(ctx); err != nil {
return fmt.Errorf("failed to refresh admin token if needed: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, url, nil)
req, err := c.NewActionsServiceRequest(ctx, http.MethodDelete, path, nil)
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *c.ActionsServiceAdminToken))
if c.userAgent != "" {
req.Header.Set("User-Agent", c.userAgent)
}
resp, err := c.Do(req)
if err != nil {
return err
@ -823,25 +750,25 @@ type registrationToken struct {
ExpiresAt *time.Time `json:"expires_at,omitempty"`
}
func (c *Client) getRunnerRegistrationToken(ctx context.Context, githubConfigUrl string, creds ActionsAuth) (*registrationToken, error) {
registrationTokenURL, err := createRegistrationTokenURL(githubConfigUrl)
func (c *Client) getRunnerRegistrationToken(ctx context.Context) (*registrationToken, error) {
path, err := createRegistrationTokenPath(c.config)
if err != nil {
return nil, err
}
var buf bytes.Buffer
req, err := http.NewRequestWithContext(ctx, http.MethodPost, registrationTokenURL, &buf)
req, err := c.NewGitHubAPIRequest(ctx, http.MethodPost, path, &buf)
if err != nil {
return nil, err
}
bearerToken := ""
if creds.Token != "" {
encodedToken := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("github:%v", creds.Token)))
if c.creds.Token != "" {
encodedToken := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("github:%v", c.creds.Token)))
bearerToken = fmt.Sprintf("Basic %v", encodedToken)
} else {
accessToken, err := c.fetchAccessToken(ctx, githubConfigUrl, creds.AppCreds)
accessToken, err := c.fetchAccessToken(ctx, c.config.ConfigURL.String(), c.creds.AppCreds)
if err != nil {
return nil, err
}
@ -851,9 +778,8 @@ func (c *Client) getRunnerRegistrationToken(ctx context.Context, githubConfigUrl
req.Header.Set("Content-Type", "application/vnd.github.v3+json")
req.Header.Set("Authorization", bearerToken)
req.Header.Set("User-Agent", c.userAgent)
c.logger.Info("getting runner registration token", "registrationTokenURL", registrationTokenURL)
c.logger.Info("getting runner registration token", "registrationTokenURL", req.URL.String())
resp, err := c.Do(req)
if err != nil {
@ -869,8 +795,8 @@ func (c *Client) getRunnerRegistrationToken(ctx context.Context, githubConfigUrl
return nil, fmt.Errorf("unexpected response from Actions service during registration token call: %v - %v", resp.StatusCode, string(body))
}
registrationToken := &registrationToken{}
if err := json.NewDecoder(resp.Body).Decode(registrationToken); err != nil {
var registrationToken *registrationToken
if err := json.NewDecoder(resp.Body).Decode(&registrationToken); err != nil {
return nil, err
}
@ -889,21 +815,16 @@ func (c *Client) fetchAccessToken(ctx context.Context, gitHubConfigURL string, c
return nil, err
}
u, err := githubAPIURL(gitHubConfigURL, fmt.Sprintf("/app/installations/%v/access_tokens", creds.AppInstallationID))
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, u, nil)
path := fmt.Sprintf("/app/installations/%v/access_tokens", creds.AppInstallationID)
req, err := c.NewGitHubAPIRequest(ctx, http.MethodPost, path, nil)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/vnd.github+json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessTokenJWT))
req.Header.Add("User-Agent", c.userAgent)
c.logger.Info("getting access token for GitHub App auth", "accessTokenURL", u)
c.logger.Info("getting access token for GitHub App auth", "accessTokenURL", req.URL.String())
resp, err := c.Do(req)
if err != nil {
@ -912,8 +833,8 @@ func (c *Client) fetchAccessToken(ctx context.Context, gitHubConfigURL string, c
defer resp.Body.Close()
// Format: https://docs.github.com/en/rest/apps/apps#create-an-installation-access-token-for-an-app
accessToken := &accessToken{}
err = json.NewDecoder(resp.Body).Decode(accessToken)
var accessToken *accessToken
err = json.NewDecoder(resp.Body).Decode(&accessToken)
return accessToken, err
}
@ -922,27 +843,14 @@ type ActionsServiceAdminConnection struct {
AdminToken *string `json:"token,omitempty"`
}
func (c *Client) getActionsServiceAdminConnection(ctx context.Context, rt *registrationToken, githubConfigUrl string) (*ActionsServiceAdminConnection, error) {
parsedGitHubConfigURL, err := url.Parse(githubConfigUrl)
if err != nil {
return nil, err
}
if isHostedServer(*parsedGitHubConfigURL) {
parsedGitHubConfigURL.Host = fmt.Sprintf("api.%v", parsedGitHubConfigURL.Host)
}
ru := fmt.Sprintf("%v://%v/actions/runner-registration", parsedGitHubConfigURL.Scheme, parsedGitHubConfigURL.Host)
registrationURL, err := url.Parse(ru)
if err != nil {
return nil, err
}
func (c *Client) getActionsServiceAdminConnection(ctx context.Context, rt *registrationToken) (*ActionsServiceAdminConnection, error) {
path := "/actions/runner-registration"
body := struct {
Url string `json:"url"`
RunnerEvent string `json:"runner_event"`
}{
Url: githubConfigUrl,
Url: c.config.ConfigURL.String(),
RunnerEvent: "register",
}
@ -954,16 +862,15 @@ func (c *Client) getActionsServiceAdminConnection(ctx context.Context, rt *regis
return nil, err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, registrationURL.String(), buf)
req, err := c.NewGitHubAPIRequest(ctx, http.MethodPost, path, buf)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("RemoteAuth %s", *rt.Token))
req.Header.Set("User-Agent", c.userAgent)
c.logger.Info("getting Actions tenant URL and JWT", "registrationURL", registrationURL.String())
c.logger.Info("getting Actions tenant URL and JWT", "registrationURL", req.URL.String())
resp, err := c.Do(req)
if err != nil {
@ -971,65 +878,30 @@ func (c *Client) getActionsServiceAdminConnection(ctx context.Context, rt *regis
}
defer resp.Body.Close()
actionsServiceAdminConnection := &ActionsServiceAdminConnection{}
if err := json.NewDecoder(resp.Body).Decode(actionsServiceAdminConnection); err != nil {
var actionsServiceAdminConnection *ActionsServiceAdminConnection
if err := json.NewDecoder(resp.Body).Decode(&actionsServiceAdminConnection); err != nil {
return nil, err
}
return actionsServiceAdminConnection, nil
}
func isHostedServer(gitHubURL url.URL) bool {
return gitHubURL.Host == "github.com" ||
gitHubURL.Host == "www.github.com" ||
gitHubURL.Host == "github.localhost"
}
func createRegistrationTokenPath(config *GitHubConfig) (string, error) {
switch config.Scope {
case GitHubScopeOrganization:
path := fmt.Sprintf("/orgs/%s/actions/runners/registration-token", config.Organization)
return path, nil
func createRegistrationTokenURL(githubConfigUrl string) (string, error) {
parsedGitHubConfigURL, err := url.Parse(githubConfigUrl)
if err != nil {
return "", err
}
case GitHubScopeEnterprise:
path := fmt.Sprintf("/enterprises/%s/actions/runners/registration-token", config.Enterprise)
return path, nil
// Check for empty path before split, because strings.Split will return a slice of length 1
// when the split delimiter is not present.
trimmedPath := strings.TrimLeft(parsedGitHubConfigURL.Path, "/")
if len(trimmedPath) == 0 {
return "", fmt.Errorf("%q should point to an enterprise, org, or repository", parsedGitHubConfigURL.String())
}
case GitHubScopeRepository:
path := fmt.Sprintf("/repos/%s/%s/actions/runners/registration-token", config.Organization, config.Repository)
return path, nil
pathParts := strings.Split(path.Clean(strings.TrimLeft(parsedGitHubConfigURL.Path, "/")), "/")
switch len(pathParts) {
case 1: // Organization
registrationTokenURL := fmt.Sprintf(
"%v://%v/api/v3/orgs/%v/actions/runners/registration-token",
parsedGitHubConfigURL.Scheme, parsedGitHubConfigURL.Host, pathParts[0])
if isHostedServer(*parsedGitHubConfigURL) {
registrationTokenURL = fmt.Sprintf(
"%v://api.%v/orgs/%v/actions/runners/registration-token",
parsedGitHubConfigURL.Scheme, parsedGitHubConfigURL.Host, pathParts[0])
}
return registrationTokenURL, nil
case 2: // Repository or enterprise
repoScope := "repos/"
if strings.ToLower(pathParts[0]) == "enterprises" {
repoScope = ""
}
registrationTokenURL := fmt.Sprintf("%v://%v/api/v3/%v%v/%v/actions/runners/registration-token",
parsedGitHubConfigURL.Scheme, parsedGitHubConfigURL.Host, repoScope, pathParts[0], pathParts[1])
if isHostedServer(*parsedGitHubConfigURL) {
registrationTokenURL = fmt.Sprintf("%v://api.%v/%v%v/%v/actions/runners/registration-token",
parsedGitHubConfigURL.Scheme, parsedGitHubConfigURL.Host, repoScope, pathParts[0], pathParts[1])
}
return registrationTokenURL, nil
default:
return "", fmt.Errorf("%q should point to an enterprise, org, or repository", parsedGitHubConfigURL.String())
return "", fmt.Errorf("unknown scope for config url: %s", config.ConfigURL)
}
}
@ -1057,68 +929,50 @@ func createJWTForGitHubApp(appAuth *GitHubAppAuth) (string, error) {
return token.SignedString(privateKey)
}
func unmarshalBody(response *http.Response, v interface{}) (err error) {
if response != nil && response.Body != nil {
var err error
defer func() {
if closeError := response.Body.Close(); closeError != nil {
err = closeError
}
}()
body, err := io.ReadAll(response.Body)
if err != nil {
return err
}
body = trimByteOrderMark(body)
return json.Unmarshal(body, &v)
}
return nil
}
// Returns slice of body without utf-8 byte order mark.
// If BOM does not exist body is returned unchanged.
func trimByteOrderMark(body []byte) []byte {
return bytes.TrimPrefix(body, []byte("\xef\xbb\xbf"))
}
func actionsServiceAdminTokenExpiresAt(jwtToken string) (*time.Time, error) {
func actionsServiceAdminTokenExpiresAt(jwtToken string) (time.Time, error) {
type JwtClaims struct {
jwt.RegisteredClaims
}
token, _, err := jwt.NewParser().ParseUnverified(jwtToken, &JwtClaims{})
if err != nil {
return nil, fmt.Errorf("failed to parse jwt token: %w", err)
return time.Time{}, fmt.Errorf("failed to parse jwt token: %w", err)
}
if claims, ok := token.Claims.(*JwtClaims); ok {
return &claims.ExpiresAt.Time, nil
return claims.ExpiresAt.Time, nil
}
return nil, fmt.Errorf("failed to parse token claims to get expire at")
return time.Time{}, fmt.Errorf("failed to parse token claims to get expire at")
}
func (c *Client) refreshTokenIfNeeded(ctx context.Context) error {
func (c *Client) updateTokenIfNeeded(ctx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
aboutToExpire := time.Now().Add(60 * time.Second).After(*c.ActionsServiceAdminTokenExpiresAt)
if !aboutToExpire {
aboutToExpire := time.Now().Add(60 * time.Second).After(c.ActionsServiceAdminTokenExpiresAt)
if !aboutToExpire && !c.ActionsServiceAdminTokenExpiresAt.IsZero() {
return nil
}
c.logger.Info("Admin token is about to expire, refreshing it", "githubConfigUrl", c.githubConfigURL)
rt, err := c.getRunnerRegistrationToken(ctx, c.githubConfigURL, *c.creds)
c.logger.Info("refreshing token", "githubConfigUrl", c.config.ConfigURL.String())
rt, err := c.getRunnerRegistrationToken(ctx)
if err != nil {
return fmt.Errorf("failed to get runner registration token on fresh: %w", err)
return fmt.Errorf("failed to get runner registration token on refresh: %w", err)
}
adminConnInfo, err := c.getActionsServiceAdminConnection(ctx, rt, c.githubConfigURL)
adminConnInfo, err := c.getActionsServiceAdminConnection(ctx, rt)
if err != nil {
return fmt.Errorf("failed to get actions service admin connection on fresh: %w", err)
return fmt.Errorf("failed to get actions service admin connection on refresh: %w", err)
}
c.ActionsServiceURL = adminConnInfo.ActionsServiceUrl
c.ActionsServiceAdminToken = adminConnInfo.AdminToken
c.ActionsServiceURL = *adminConnInfo.ActionsServiceUrl
c.ActionsServiceAdminToken = *adminConnInfo.AdminToken
c.ActionsServiceAdminTokenExpiresAt, err = actionsServiceAdminTokenExpiresAt(*adminConnInfo.AdminToken)
if err != nil {
return fmt.Errorf("failed to get admin token expire at on refresh: %w", err)
@ -1126,32 +980,3 @@ func (c *Client) refreshTokenIfNeeded(ctx context.Context) error {
return nil
}
func githubAPIURL(configURL, path string) (string, error) {
u, err := url.Parse(configURL)
if err != nil {
return "", err
}
result := &url.URL{
Scheme: u.Scheme,
}
switch u.Host {
// Hosted
case "github.com", "github.localhost":
result.Host = fmt.Sprintf("api.%s", u.Host)
// re-routing www.github.com to api.github.com
case "www.github.com":
result.Host = "api.github.com"
// Enterprise
default:
result.Host = u.Host
result.Path = "/api/v3"
}
result.Path += path
return result.String(), nil
}

View File

@ -26,7 +26,7 @@ func TestGenerateJitRunnerConfig(t *testing.T) {
server := newActionsServer(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Write(response)
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
got, err := client.GenerateJitRunnerConfig(ctx, runnerSettings, 1)
@ -47,7 +47,6 @@ func TestGenerateJitRunnerConfig(t *testing.T) {
}))
client, err := actions.NewClient(
ctx,
server.configURLForOrg("my-org"),
auth,
actions.WithRetryMax(1),

View File

@ -3,6 +3,7 @@ package actions_test
import (
"context"
"net/http"
"strings"
"testing"
"time"
@ -27,11 +28,19 @@ func TestAcquireJobs(t *testing.T) {
}
requestIDs := want
server := newActionsServer(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
server := newActionsServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, "/acquirablejobs") {
w.Write([]byte(`{"count": 1}`))
return
}
w.Write(response)
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
_, err = client.GetAcquirableJobs(ctx, 1)
require.NoError(t, err)
got, err := client.AcquireJobs(ctx, session.RunnerScaleSet.Id, session.MessageQueueAccessToken, requestIDs)
@ -50,13 +59,17 @@ func TestAcquireJobs(t *testing.T) {
actualRetry := 0
expectedRetry := retryMax + 1
server := newActionsServer(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
server := newActionsServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, "/acquirablejobs") {
w.Write([]byte(`{"count": 1}`))
return
}
w.WriteHeader(http.StatusServiceUnavailable)
actualRetry++
}))
client, err := actions.NewClient(
ctx,
server.configURLForOrg("my-org"),
auth,
actions.WithRetryMax(retryMax),
@ -64,6 +77,9 @@ func TestAcquireJobs(t *testing.T) {
)
require.NoError(t, err)
_, err = client.GetAcquirableJobs(ctx, 1)
require.NoError(t, err)
_, err = client.AcquireJobs(context.Background(), session.RunnerScaleSet.Id, session.MessageQueueAccessToken, requestIDs)
assert.NotNil(t, err)
assert.Equalf(t, actualRetry, expectedRetry, "A retry was expected after the first request but got: %v", actualRetry)
@ -71,7 +87,6 @@ func TestAcquireJobs(t *testing.T) {
}
func TestGetAcquirableJobs(t *testing.T) {
ctx := context.Background()
auth := &actions.ActionsAuth{
Token: "token",
}
@ -86,7 +101,7 @@ func TestGetAcquirableJobs(t *testing.T) {
w.Write(response)
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
got, err := client.GetAcquirableJobs(context.Background(), runnerScaleSet.Id)
@ -108,7 +123,6 @@ func TestGetAcquirableJobs(t *testing.T) {
}))
client, err := actions.NewClient(
context.Background(),
server.configURLForOrg("my-org"),
auth,
actions.WithRetryMax(retryMax),

View File

@ -32,7 +32,7 @@ func TestGetMessage(t *testing.T) {
w.Write(response)
}))
client, err := actions.NewClient(ctx, s.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(s.configURLForOrg("my-org"), auth)
require.NoError(t, err)
got, err := client.GetMessage(ctx, s.URL, token, 0)
@ -40,6 +40,23 @@ func TestGetMessage(t *testing.T) {
assert.Equal(t, want, got)
})
t.Run("GetMessage sets the last message id if not 0", func(t *testing.T) {
want := runnerScaleSetMessage
response := []byte(`{"messageId":1,"messageType":"rssType"}`)
s := newActionsServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
assert.Equal(t, "1", q.Get("lastMessageId"))
w.Write(response)
}))
client, err := actions.NewClient(s.configURLForOrg("my-org"), auth)
require.NoError(t, err)
got, err := client.GetMessage(ctx, s.URL, token, 1)
require.NoError(t, err)
assert.Equal(t, want, got)
})
t.Run("Default retries on server error", func(t *testing.T) {
retryMax := 1
@ -52,7 +69,6 @@ func TestGetMessage(t *testing.T) {
}))
client, err := actions.NewClient(
ctx,
server.configURLForOrg("my-org"),
auth,
actions.WithRetryMax(retryMax),
@ -70,7 +86,7 @@ func TestGetMessage(t *testing.T) {
w.WriteHeader(http.StatusUnauthorized)
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
_, err = client.GetMessage(ctx, server.URL, token, 0)
@ -78,8 +94,7 @@ func TestGetMessage(t *testing.T) {
var expectedErr *actions.MessageQueueTokenExpiredError
require.True(t, errors.As(err, &expectedErr))
},
)
})
t.Run("Status code not found", func(t *testing.T) {
want := actions.ActionsError{
@ -90,7 +105,7 @@ func TestGetMessage(t *testing.T) {
w.WriteHeader(http.StatusNotFound)
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
_, err = client.GetMessage(ctx, server.URL, token, 0)
@ -104,7 +119,7 @@ func TestGetMessage(t *testing.T) {
w.Header().Set("Content-Type", "text/plain")
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
_, err = client.GetMessage(ctx, server.URL, token, 0)
@ -129,7 +144,7 @@ func TestDeleteMessage(t *testing.T) {
w.WriteHeader(http.StatusNoContent)
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
err = client.DeleteMessage(ctx, server.URL, token, runnerScaleSetMessage.MessageId)
@ -141,7 +156,7 @@ func TestDeleteMessage(t *testing.T) {
w.WriteHeader(http.StatusUnauthorized)
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
err = client.DeleteMessage(ctx, server.URL, token, 0)
@ -156,7 +171,7 @@ func TestDeleteMessage(t *testing.T) {
w.Header().Set("Content-Type", "text/plain")
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
err = client.DeleteMessage(ctx, server.URL, token, runnerScaleSetMessage.MessageId)
@ -175,7 +190,6 @@ func TestDeleteMessage(t *testing.T) {
retryMax := 1
client, err := actions.NewClient(
ctx,
server.configURLForOrg("my-org"),
auth,
actions.WithRetryMax(retryMax),
@ -197,7 +211,7 @@ func TestDeleteMessage(t *testing.T) {
w.Write(rsl)
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
err = client.DeleteMessage(ctx, server.URL, token, runnerScaleSetMessage.MessageId+1)

View File

@ -51,7 +51,7 @@ func TestCreateMessageSession(t *testing.T) {
w.Write(resp)
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
got, err := client.CreateMessageSession(ctx, runnerScaleSet.Id, owner)
@ -81,7 +81,7 @@ func TestCreateMessageSession(t *testing.T) {
w.Write(resp)
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
_, err = client.CreateMessageSession(ctx, runnerScaleSet.Id, owner)
@ -120,7 +120,6 @@ func TestCreateMessageSession(t *testing.T) {
wantRetries := retryMax + 1
client, err := actions.NewClient(
ctx,
server.configURLForOrg("my-org"),
auth,
actions.WithRetryMax(retryMax),
@ -160,7 +159,6 @@ func TestDeleteMessageSession(t *testing.T) {
wantRetries := retryMax + 1
client, err := actions.NewClient(
ctx,
server.configURLForOrg("my-org"),
auth,
actions.WithRetryMax(retryMax),
@ -177,7 +175,6 @@ func TestDeleteMessageSession(t *testing.T) {
}
func TestRefreshMessageSession(t *testing.T) {
ctx := context.Background()
auth := &actions.ActionsAuth{
Token: "token",
}
@ -202,7 +199,6 @@ func TestRefreshMessageSession(t *testing.T) {
wantRetries := retryMax + 1
client, err := actions.NewClient(
ctx,
server.configURLForOrg("my-org"),
auth,
actions.WithRetryMax(retryMax),

View File

@ -31,7 +31,7 @@ func TestGetRunnerScaleSet(t *testing.T) {
w.Write(runnerScaleSetsResp)
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
got, err := client.GetRunnerScaleSet(ctx, scaleSetName)
@ -47,15 +47,16 @@ func TestGetRunnerScaleSet(t *testing.T) {
url = *r.URL
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
_, err = client.GetRunnerScaleSet(ctx, scaleSetName)
require.NoError(t, err)
u := url.String()
expectedUrl := fmt.Sprintf("/_apis/runtime/runnerscalesets?name=%s&api-version=6.0-preview", scaleSetName)
assert.Equal(t, expectedUrl, u)
expectedPath := "/tenant/123/_apis/runtime/runnerscalesets"
assert.Equal(t, expectedPath, url.Path)
assert.Equal(t, scaleSetName, url.Query().Get("name"))
assert.Equal(t, "6.0-preview", url.Query().Get("api-version"))
})
t.Run("Status code not found", func(t *testing.T) {
@ -63,7 +64,7 @@ func TestGetRunnerScaleSet(t *testing.T) {
w.WriteHeader(http.StatusNotFound)
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
_, err = client.GetRunnerScaleSet(ctx, scaleSetName)
@ -76,7 +77,7 @@ func TestGetRunnerScaleSet(t *testing.T) {
w.Header().Set("Content-Type", "text/plain")
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
_, err = client.GetRunnerScaleSet(ctx, scaleSetName)
@ -94,7 +95,6 @@ func TestGetRunnerScaleSet(t *testing.T) {
retryWaitMax := 1 * time.Microsecond
client, err := actions.NewClient(
ctx,
server.configURLForOrg("my-org"),
auth,
actions.WithRetryMax(retryMax),
@ -115,7 +115,7 @@ func TestGetRunnerScaleSet(t *testing.T) {
w.Write(runnerScaleSetsResp)
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
got, err := client.GetRunnerScaleSet(ctx, scaleSetName)
@ -130,7 +130,7 @@ func TestGetRunnerScaleSet(t *testing.T) {
w.Write(runnerScaleSetsResp)
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
_, err = client.GetRunnerScaleSet(ctx, scaleSetName)
@ -156,7 +156,7 @@ func TestGetRunnerScaleSetById(t *testing.T) {
w.Write(rsl)
}))
client, err := actions.NewClient(ctx, sservere.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(sservere.configURLForOrg("my-org"), auth)
require.NoError(t, err)
got, err := client.GetRunnerScaleSetById(ctx, runnerScaleSet.Id)
@ -174,15 +174,15 @@ func TestGetRunnerScaleSetById(t *testing.T) {
url = *r.URL
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
_, err = client.GetRunnerScaleSetById(ctx, runnerScaleSet.Id)
require.NoError(t, err)
u := url.String()
expectedUrl := fmt.Sprintf("/_apis/runtime/runnerscalesets/%d?api-version=6.0-preview", runnerScaleSet.Id)
assert.Equal(t, expectedUrl, u)
expectedPath := fmt.Sprintf("/tenant/123/_apis/runtime/runnerscalesets/%d", runnerScaleSet.Id)
assert.Equal(t, expectedPath, url.Path)
assert.Equal(t, "6.0-preview", url.Query().Get("api-version"))
})
t.Run("Status code not found", func(t *testing.T) {
@ -190,7 +190,7 @@ func TestGetRunnerScaleSetById(t *testing.T) {
w.WriteHeader(http.StatusNotFound)
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
_, err = client.GetRunnerScaleSetById(ctx, runnerScaleSet.Id)
@ -203,7 +203,7 @@ func TestGetRunnerScaleSetById(t *testing.T) {
w.Header().Set("Content-Type", "text/plain")
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
_, err = client.GetRunnerScaleSetById(ctx, runnerScaleSet.Id)
@ -220,7 +220,6 @@ func TestGetRunnerScaleSetById(t *testing.T) {
retryMax := 1
retryWaitMax := 1 * time.Microsecond
client, err := actions.NewClient(
ctx,
server.configURLForOrg("my-org"),
auth,
actions.WithRetryMax(retryMax),
@ -242,7 +241,7 @@ func TestGetRunnerScaleSetById(t *testing.T) {
w.Write(rsl)
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
got, err := client.GetRunnerScaleSetById(ctx, runnerScaleSet.Id)
@ -268,7 +267,7 @@ func TestCreateRunnerScaleSet(t *testing.T) {
w.Write(rsl)
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
got, err := client.CreateRunnerScaleSet(ctx, &runnerScaleSet)
@ -285,15 +284,15 @@ func TestCreateRunnerScaleSet(t *testing.T) {
url = *r.URL
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
_, err = client.CreateRunnerScaleSet(ctx, &runnerScaleSet)
require.NoError(t, err)
u := url.String()
expectedUrl := "/_apis/runtime/runnerscalesets?api-version=6.0-preview"
assert.Equal(t, expectedUrl, u)
expectedPath := "/tenant/123/_apis/runtime/runnerscalesets"
assert.Equal(t, expectedPath, url.Path)
assert.Equal(t, "6.0-preview", url.Query().Get("api-version"))
})
t.Run("Error when Content-Type is text/plain", func(t *testing.T) {
@ -302,7 +301,7 @@ func TestCreateRunnerScaleSet(t *testing.T) {
w.Header().Set("Content-Type", "text/plain")
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
_, err = client.CreateRunnerScaleSet(ctx, &runnerScaleSet)
@ -322,7 +321,6 @@ func TestCreateRunnerScaleSet(t *testing.T) {
retryWaitMax := 1 * time.Microsecond
client, err := actions.NewClient(
ctx,
server.configURLForOrg("my-org"),
auth,
actions.WithRetryMax(retryMax),
@ -354,7 +352,7 @@ func TestUpdateRunnerScaleSet(t *testing.T) {
w.Write(rsl)
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
got, err := client.UpdateRunnerScaleSet(ctx, 1, &actions.RunnerScaleSet{RunnerGroupId: 1})
@ -365,24 +363,19 @@ func TestUpdateRunnerScaleSet(t *testing.T) {
t.Run("UpdateRunnerScaleSet calls correct url", func(t *testing.T) {
rsl, err := json.Marshal(&runnerScaleSet)
require.NoError(t, err)
url := url.URL{}
method := ""
server := newActionsServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
expectedPath := "/tenant/123/_apis/runtime/runnerscalesets/1"
assert.Equal(t, expectedPath, r.URL.Path)
assert.Equal(t, http.MethodPatch, r.Method)
assert.Equal(t, "6.0-preview", r.URL.Query().Get("api-version"))
w.Write(rsl)
url = *r.URL
method = r.Method
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
_, err = client.UpdateRunnerScaleSet(ctx, 1, &runnerScaleSet)
require.NoError(t, err)
u := url.String()
expectedUrl := "/_apis/runtime/runnerscalesets/1?api-version=6.0-preview"
assert.Equal(t, expectedUrl, u)
assert.Equal(t, "PATCH", method)
})
}

View File

@ -29,7 +29,7 @@ func TestGetRunner(t *testing.T) {
w.Write(response)
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
got, err := client.GetRunner(ctx, runnerID)
@ -50,7 +50,7 @@ func TestGetRunner(t *testing.T) {
actualRetry++
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth, actions.WithRetryMax(retryMax), actions.WithRetryWaitMax(retryWaitMax))
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth, actions.WithRetryMax(retryMax), actions.WithRetryWaitMax(retryWaitMax))
require.NoError(t, err)
_, err = client.GetRunner(ctx, runnerID)
@ -78,7 +78,7 @@ func TestGetRunnerByName(t *testing.T) {
w.Write(response)
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
got, err := client.GetRunnerByName(ctx, runnerName)
@ -94,7 +94,7 @@ func TestGetRunnerByName(t *testing.T) {
w.Write(response)
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
got, err := client.GetRunnerByName(ctx, runnerName)
@ -116,7 +116,7 @@ func TestGetRunnerByName(t *testing.T) {
actualRetry++
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth, actions.WithRetryMax(retryMax), actions.WithRetryWaitMax(retryWaitMax))
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth, actions.WithRetryMax(retryMax), actions.WithRetryWaitMax(retryWaitMax))
require.NoError(t, err)
_, err = client.GetRunnerByName(ctx, runnerName)
@ -138,7 +138,7 @@ func TestDeleteRunner(t *testing.T) {
w.WriteHeader(http.StatusNoContent)
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
err = client.RemoveRunner(ctx, runnerID)
@ -160,7 +160,6 @@ func TestDeleteRunner(t *testing.T) {
}))
client, err := actions.NewClient(
ctx,
server.configURLForOrg("my-org"),
auth,
actions.WithRetryMax(retryMax),
@ -193,7 +192,7 @@ func TestGetRunnerGroupByName(t *testing.T) {
w.Write(response)
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
got, err := client.GetRunnerGroupByName(ctx, runnerGroupName)
@ -209,7 +208,7 @@ func TestGetRunnerGroupByName(t *testing.T) {
w.Write(response)
}))
client, err := actions.NewClient(ctx, server.configURLForOrg("my-org"), auth)
client, err := actions.NewClient(server.configURLForOrg("my-org"), auth)
require.NoError(t, err)
got, err := client.GetRunnerGroupByName(ctx, runnerGroupName)

View File

@ -22,9 +22,9 @@ import (
func TestServerWithSelfSignedCertificates(t *testing.T) {
ctx := context.Background()
// this handler is a very very barebones replica of actions api
// used during the creation of a a new client
var u string
h := func(w http.ResponseWriter, r *http.Request) {
// handle get registration token
if strings.HasSuffix(r.URL.Path, "/runners/registration-token") {
@ -46,9 +46,12 @@ func TestServerWithSelfSignedCertificates(t *testing.T) {
require.NoError(t, err)
tokenString, err := token.SignedString(privateKey)
require.NoError(t, err)
w.Write([]byte(`{"url":"TODO","token":"` + tokenString + `"}`))
w.Write([]byte(`{"url":"` + u + `","token":"` + tokenString + `"}`))
return
}
// default happy response for RemoveRunner
w.WriteHeader(http.StatusNoContent)
}
certPath := filepath.Join("testdata", "server.crt")
@ -56,13 +59,17 @@ func TestServerWithSelfSignedCertificates(t *testing.T) {
t.Run("client without ca certs", func(t *testing.T) {
server := startNewTLSTestServer(t, certPath, keyPath, http.HandlerFunc(h))
u = server.URL
configURL := server.URL + "/my-org"
auth := &actions.ActionsAuth{
Token: "token",
}
client, err := actions.NewClient(ctx, configURL, auth)
assert.Nil(t, client)
client, err := actions.NewClient(configURL, auth)
require.NoError(t, err)
require.NotNil(t, client)
err = client.RemoveRunner(ctx, 1)
require.NotNil(t, err)
if runtime.GOOS == "linux" {
@ -78,6 +85,7 @@ func TestServerWithSelfSignedCertificates(t *testing.T) {
t.Run("client with ca certs", func(t *testing.T) {
server := startNewTLSTestServer(t, certPath, keyPath, http.HandlerFunc(h))
u = server.URL
configURL := server.URL + "/my-org"
auth := &actions.ActionsAuth{
@ -90,9 +98,12 @@ func TestServerWithSelfSignedCertificates(t *testing.T) {
pool, err := actions.RootCAsFromConfigMap(map[string][]byte{"cert": cert})
require.NoError(t, err)
client, err := actions.NewClient(ctx, configURL, auth, actions.WithRootCAs(pool))
client, err := actions.NewClient(configURL, auth, actions.WithRootCAs(pool))
require.NoError(t, err)
assert.NotNil(t, client)
err = client.RemoveRunner(ctx, 1)
assert.NoError(t, err)
})
t.Run("client with ca chain certs", func(t *testing.T) {
@ -102,6 +113,7 @@ func TestServerWithSelfSignedCertificates(t *testing.T) {
filepath.Join("testdata", "leaf.key"),
http.HandlerFunc(h),
)
u = server.URL
configURL := server.URL + "/my-org"
auth := &actions.ActionsAuth{
@ -114,9 +126,12 @@ func TestServerWithSelfSignedCertificates(t *testing.T) {
pool, err := actions.RootCAsFromConfigMap(map[string][]byte{"cert": cert})
require.NoError(t, err)
client, err := actions.NewClient(ctx, configURL, auth, actions.WithRootCAs(pool), actions.WithRetryMax(0))
client, err := actions.NewClient(configURL, auth, actions.WithRootCAs(pool), actions.WithRetryMax(0))
require.NoError(t, err)
assert.NotNil(t, client)
require.NotNil(t, client)
err = client.RemoveRunner(ctx, 1)
assert.NoError(t, err)
})
t.Run("client skipping tls verification", func(t *testing.T) {
@ -127,7 +142,7 @@ func TestServerWithSelfSignedCertificates(t *testing.T) {
Token: "token",
}
client, err := actions.NewClient(ctx, configURL, auth, actions.WithoutTLSVerify())
client, err := actions.NewClient(configURL, auth, actions.WithoutTLSVerify())
require.NoError(t, err)
assert.NotNil(t, client)
})

98
github/actions/config.go Normal file
View File

@ -0,0 +1,98 @@
package actions
import (
"fmt"
"net/url"
"strings"
)
var ErrInvalidGitHubConfigURL = fmt.Errorf("invalid config URL, should point to an enterprise, org, or repository")
type GitHubScope int
const (
GitHubScopeUnknown GitHubScope = iota
GitHubScopeEnterprise
GitHubScopeOrganization
GitHubScopeRepository
)
type GitHubConfig struct {
ConfigURL *url.URL
Scope GitHubScope
Enterprise string
Organization string
Repository string
IsHosted bool
}
func ParseGitHubConfigFromURL(in string) (*GitHubConfig, error) {
u, err := url.Parse(in)
if err != nil {
return nil, err
}
isHosted := u.Host == "github.com" ||
u.Host == "www.github.com" ||
u.Host == "github.localhost"
configURL := &GitHubConfig{
ConfigURL: u,
IsHosted: isHosted,
}
invalidURLError := fmt.Errorf("%q: %w", u.String(), ErrInvalidGitHubConfigURL)
pathParts := strings.Split(strings.TrimPrefix(u.Path, "/"), "/")
switch len(pathParts) {
case 1: // Organization
if pathParts[0] == "" {
return nil, invalidURLError
}
configURL.Scope = GitHubScopeOrganization
configURL.Organization = pathParts[0]
case 2: // Repository or enterprise
if strings.ToLower(pathParts[0]) == "enterprises" {
configURL.Scope = GitHubScopeEnterprise
configURL.Enterprise = pathParts[1]
break
}
configURL.Scope = GitHubScopeRepository
configURL.Organization = pathParts[0]
configURL.Repository = pathParts[1]
default:
return nil, invalidURLError
}
return configURL, nil
}
func (c *GitHubConfig) GitHubAPIURL(path string) *url.URL {
result := &url.URL{
Scheme: c.ConfigURL.Scheme,
}
switch c.ConfigURL.Host {
// Hosted
case "github.com", "github.localhost":
result.Host = fmt.Sprintf("api.%s", c.ConfigURL.Host)
// re-routing www.github.com to api.github.com
case "www.github.com":
result.Host = "api.github.com"
// Enterprise
default:
result.Host = c.ConfigURL.Host
result.Path = "/api/v3"
}
result.Path += path
return result
}

View File

@ -0,0 +1,117 @@
package actions_test
import (
"errors"
"net/url"
"testing"
"github.com/actions/actions-runner-controller/github/actions"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGitHubConfig(t *testing.T) {
t.Run("when given a valid URL", func(t *testing.T) {
tests := []struct {
configURL string
expected *actions.GitHubConfig
}{
{
configURL: "https://github.com/org/repo",
expected: &actions.GitHubConfig{
Scope: actions.GitHubScopeRepository,
Enterprise: "",
Organization: "org",
Repository: "repo",
IsHosted: true,
},
},
{
configURL: "https://github.com/org",
expected: &actions.GitHubConfig{
Scope: actions.GitHubScopeOrganization,
Enterprise: "",
Organization: "org",
Repository: "",
IsHosted: true,
},
},
{
configURL: "https://github.com/enterprises/my-enterprise",
expected: &actions.GitHubConfig{
Scope: actions.GitHubScopeEnterprise,
Enterprise: "my-enterprise",
Organization: "",
Repository: "",
IsHosted: true,
},
},
{
configURL: "https://www.github.com/org",
expected: &actions.GitHubConfig{
Scope: actions.GitHubScopeOrganization,
Enterprise: "",
Organization: "org",
Repository: "",
IsHosted: true,
},
},
{
configURL: "https://github.localhost/org",
expected: &actions.GitHubConfig{
Scope: actions.GitHubScopeOrganization,
Enterprise: "",
Organization: "org",
Repository: "",
IsHosted: true,
},
},
{
configURL: "https://my-ghes.com/org",
expected: &actions.GitHubConfig{
Scope: actions.GitHubScopeOrganization,
Enterprise: "",
Organization: "org",
Repository: "",
IsHosted: false,
},
},
}
for _, test := range tests {
t.Run(test.configURL, func(t *testing.T) {
parsedURL, err := url.Parse(test.configURL)
require.NoError(t, err)
test.expected.ConfigURL = parsedURL
cfg, err := actions.ParseGitHubConfigFromURL(test.configURL)
require.NoError(t, err)
assert.Equal(t, test.expected, cfg)
})
}
})
t.Run("when given an invalid URL", func(t *testing.T) {})
invalidURLs := []string{
"https://github.com/",
"https://github.com",
"https://github.com/some/random/path",
}
for _, u := range invalidURLs {
_, err := actions.ParseGitHubConfigFromURL(u)
require.Error(t, err)
assert.True(t, errors.Is(err, actions.ErrInvalidGitHubConfigURL))
}
}
func TestGitHubConfig_GitHubAPIURL(t *testing.T) {
t.Run("when hosted", func(t *testing.T) {
config, err := actions.ParseGitHubConfigFromURL("https://github.com/org/repo")
require.NoError(t, err)
result := config.GitHubAPIURL("/some/path")
assert.Equal(t, "https://api.github.com/some/path", result.String())
})
t.Run("when not hosted", func(t *testing.T) {})
}

View File

@ -0,0 +1,171 @@
package actions_test
import (
"context"
"io"
"net/http"
"net/url"
"strings"
"testing"
"time"
"github.com/actions/actions-runner-controller/github/actions"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewGitHubAPIRequest(t *testing.T) {
ctx := context.Background()
t.Run("uses the right host/path prefix", func(t *testing.T) {
scenarios := []struct {
configURL string
path string
expected string
}{
{
configURL: "https://github.com/org/repo",
path: "/app/installations/123/access_tokens",
expected: "https://api.github.com/app/installations/123/access_tokens",
},
{
configURL: "https://www.github.com/org/repo",
path: "/app/installations/123/access_tokens",
expected: "https://api.github.com/app/installations/123/access_tokens",
},
{
configURL: "http://github.localhost/org/repo",
path: "/app/installations/123/access_tokens",
expected: "http://api.github.localhost/app/installations/123/access_tokens",
},
{
configURL: "https://my-instance.com/org/repo",
path: "/app/installations/123/access_tokens",
expected: "https://my-instance.com/api/v3/app/installations/123/access_tokens",
},
{
configURL: "http://localhost/org/repo",
path: "/app/installations/123/access_tokens",
expected: "http://localhost/api/v3/app/installations/123/access_tokens",
},
}
for _, scenario := range scenarios {
client, err := actions.NewClient(scenario.configURL, nil)
require.NoError(t, err)
req, err := client.NewGitHubAPIRequest(ctx, http.MethodGet, scenario.path, nil)
require.NoError(t, err)
assert.Equal(t, scenario.expected, req.URL.String())
}
})
t.Run("sets user agent header if present", func(t *testing.T) {
client, err := actions.NewClient("http://localhost/my-org", nil, actions.WithUserAgent("my-agent"))
require.NoError(t, err)
req, err := client.NewGitHubAPIRequest(ctx, http.MethodGet, "/app/installations/123/access_tokens", nil)
require.NoError(t, err)
assert.Equal(t, "my-agent", req.Header.Get("User-Agent"))
})
t.Run("sets the body we pass", func(t *testing.T) {
client, err := actions.NewClient("http://localhost/my-org", nil)
require.NoError(t, err)
req, err := client.NewGitHubAPIRequest(
ctx,
http.MethodGet,
"/app/installations/123/access_tokens",
strings.NewReader("the-body"),
)
require.NoError(t, err)
b, err := io.ReadAll(req.Body)
require.NoError(t, err)
assert.Equal(t, "the-body", string(b))
})
}
func TestNewActionsServiceRequest(t *testing.T) {
ctx := context.Background()
defaultCreds := &actions.ActionsAuth{Token: "token"}
t.Run("manages authentication", func(t *testing.T) {
t.Run("client is brand new", func(t *testing.T) {
token := defaultActionsToken(t)
server := newActionsServer(t, nil, withActionsToken(token))
client, err := actions.NewClient(server.configURLForOrg("my-org"), defaultCreds)
require.NoError(t, err)
req, err := client.NewActionsServiceRequest(ctx, http.MethodGet, "my-path", nil)
require.NoError(t, err)
assert.Equal(t, "Bearer "+token, req.Header.Get("Authorization"))
})
t.Run("admin token is about to expire", func(t *testing.T) {
newToken := defaultActionsToken(t)
server := newActionsServer(t, nil, withActionsToken(newToken))
client, err := actions.NewClient(server.configURLForOrg("my-org"), defaultCreds)
require.NoError(t, err)
client.ActionsServiceAdminToken = "expiring-token"
client.ActionsServiceAdminTokenExpiresAt = time.Now().Add(59 * time.Second)
req, err := client.NewActionsServiceRequest(ctx, http.MethodGet, "my-path", nil)
require.NoError(t, err)
assert.Equal(t, "Bearer "+newToken, req.Header.Get("Authorization"))
})
t.Run("token is currently valid", func(t *testing.T) {
tokenThatShouldNotBeFetched := defaultActionsToken(t)
server := newActionsServer(t, nil, withActionsToken(tokenThatShouldNotBeFetched))
client, err := actions.NewClient(server.configURLForOrg("my-org"), defaultCreds)
require.NoError(t, err)
client.ActionsServiceAdminToken = "healthy-token"
client.ActionsServiceAdminTokenExpiresAt = time.Now().Add(1 * time.Hour)
req, err := client.NewActionsServiceRequest(ctx, http.MethodGet, "my-path", nil)
require.NoError(t, err)
assert.Equal(t, "Bearer healthy-token", req.Header.Get("Authorization"))
})
})
t.Run("builds the right URL including api version", func(t *testing.T) {
server := newActionsServer(t, nil)
client, err := actions.NewClient(server.configURLForOrg("my-org"), defaultCreds)
require.NoError(t, err)
req, err := client.NewActionsServiceRequest(ctx, http.MethodGet, "/my/path?name=banana", nil)
require.NoError(t, err)
serverURL, err := url.Parse(server.URL)
require.NoError(t, err)
result := req.URL
assert.Equal(t, serverURL.Host, result.Host)
assert.Equal(t, "/tenant/123/my/path", result.Path)
assert.Equal(t, "banana", result.Query().Get("name"))
assert.Equal(t, "6.0-preview", result.Query().Get("api-version"))
})
t.Run("populates header", func(t *testing.T) {
server := newActionsServer(t, nil)
client, err := actions.NewClient(server.configURLForOrg("my-org"), defaultCreds, actions.WithUserAgent("my-agent"))
require.NoError(t, err)
req, err := client.NewActionsServiceRequest(ctx, http.MethodGet, "/my/path", nil)
require.NoError(t, err)
assert.Equal(t, "my-agent", req.Header.Get("User-Agent"))
assert.Equal(t, "application/json", req.Header.Get("Content-Type"))
})
}

View File

@ -106,7 +106,6 @@ func (m *multiClient) GetClientFor(ctx context.Context, githubConfigURL string,
m.logger.Info("creating new client", "githubConfigURL", githubConfigURL, "namespace", namespace)
client, err := NewClient(
ctx,
githubConfigURL,
&creds,
WithUserAgent(m.userAgent),

View File

@ -1,48 +0,0 @@
package actions
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGithubAPIURL(t *testing.T) {
tests := []struct {
configURL string
path string
expected string
}{
{
configURL: "https://github.com/org/repo",
path: "/app/installations/123/access_tokens",
expected: "https://api.github.com/app/installations/123/access_tokens",
},
{
configURL: "https://www.github.com/org/repo",
path: "/app/installations/123/access_tokens",
expected: "https://api.github.com/app/installations/123/access_tokens",
},
{
configURL: "http://github.localhost/org/repo",
path: "/app/installations/123/access_tokens",
expected: "http://api.github.localhost/app/installations/123/access_tokens",
},
{
configURL: "https://my-instance.com/org/repo",
path: "/app/installations/123/access_tokens",
expected: "https://my-instance.com/api/v3/app/installations/123/access_tokens",
},
{
configURL: "http://localhost/org/repo",
path: "/app/installations/123/access_tokens",
expected: "http://localhost/api/v3/app/installations/123/access_tokens",
},
}
for _, test := range tests {
actual, err := githubAPIURL(test.configURL, test.path)
require.NoError(t, err)
assert.Equal(t, test.expected, actual)
}
}