1093 lines
29 KiB
Go
1093 lines
29 KiB
Go
package actions
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/sha256"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/actions/actions-runner-controller/build"
|
|
"github.com/go-logr/logr"
|
|
"github.com/golang-jwt/jwt/v4"
|
|
"github.com/google/uuid"
|
|
"github.com/hashicorp/go-retryablehttp"
|
|
)
|
|
|
|
// Paths of Actions service collections (joined onto the Actions service URL)
// and the api-version query parameter as a raw key=value pair.
const (
	// scaleSetEndpoint is the runner scale set collection.
	scaleSetEndpoint = "_apis/runtime/runnerscalesets"

	// runnerEndpoint is the agents collection under pool 0.
	runnerEndpoint = "_apis/distributedtask/pools/0/agents"

	// apiVersionQueryParam is the api-version parameter in "key=value" form.
	apiVersionQueryParam = "api-version=6.0-preview"
)
|
|
|
|
//go:generate mockery --inpackage --name=ActionsService
|
|
type ActionsService interface {
|
|
GetRunnerScaleSet(ctx context.Context, runnerGroupId int, runnerScaleSetName string) (*RunnerScaleSet, error)
|
|
GetRunnerScaleSetById(ctx context.Context, runnerScaleSetId int) (*RunnerScaleSet, error)
|
|
GetRunnerGroupByName(ctx context.Context, runnerGroup string) (*RunnerGroup, error)
|
|
CreateRunnerScaleSet(ctx context.Context, runnerScaleSet *RunnerScaleSet) (*RunnerScaleSet, error)
|
|
UpdateRunnerScaleSet(ctx context.Context, runnerScaleSetId int, runnerScaleSet *RunnerScaleSet) (*RunnerScaleSet, error)
|
|
DeleteRunnerScaleSet(ctx context.Context, runnerScaleSetId int) error
|
|
|
|
CreateMessageSession(ctx context.Context, runnerScaleSetId int, owner string) (*RunnerScaleSetSession, error)
|
|
DeleteMessageSession(ctx context.Context, runnerScaleSetId int, sessionId *uuid.UUID) error
|
|
RefreshMessageSession(ctx context.Context, runnerScaleSetId int, sessionId *uuid.UUID) (*RunnerScaleSetSession, error)
|
|
|
|
AcquireJobs(ctx context.Context, runnerScaleSetId int, messageQueueAccessToken string, requestIds []int64) ([]int64, error)
|
|
GetAcquirableJobs(ctx context.Context, runnerScaleSetId int) (*AcquirableJobList, error)
|
|
|
|
GetMessage(ctx context.Context, messageQueueUrl, messageQueueAccessToken string, lastMessageId int64) (*RunnerScaleSetMessage, error)
|
|
DeleteMessage(ctx context.Context, messageQueueUrl, messageQueueAccessToken string, messageId int64) error
|
|
|
|
GenerateJitRunnerConfig(ctx context.Context, jitRunnerSetting *RunnerScaleSetJitRunnerSetting, scaleSetId int) (*RunnerScaleSetJitRunnerConfig, error)
|
|
|
|
GetRunner(ctx context.Context, runnerId int64) (*RunnerReference, error)
|
|
GetRunnerByName(ctx context.Context, runnerName string) (*RunnerReference, error)
|
|
RemoveRunner(ctx context.Context, runnerId int64) error
|
|
|
|
SetUserAgent(info UserAgentInfo)
|
|
}
|
|
|
|
type clientLogger struct {
|
|
logr.Logger
|
|
}
|
|
|
|
func (l *clientLogger) Info(msg string, keysAndValues ...interface{}) {
|
|
l.Logger.Info(msg, keysAndValues...)
|
|
}
|
|
|
|
func (l *clientLogger) Debug(msg string, keysAndValues ...interface{}) {
|
|
// discard debug log
|
|
}
|
|
|
|
func (l *clientLogger) Error(msg string, keysAndValues ...interface{}) {
|
|
l.Logger.Error(errors.New(msg), "Retryable client error", keysAndValues...)
|
|
}
|
|
|
|
func (l *clientLogger) Warn(msg string, keysAndValues ...interface{}) {
|
|
l.Logger.Info(msg, keysAndValues...)
|
|
}
|
|
|
|
var _ retryablehttp.LeveledLogger = &clientLogger{}
|
|
|
|
type Client struct {
|
|
*http.Client
|
|
|
|
// lock for refreshing the ActionsServiceAdminToken and ActionsServiceAdminTokenExpiresAt
|
|
mu sync.Mutex
|
|
|
|
// TODO: Convert to unexported fields once refactor of Listener is complete
|
|
ActionsServiceAdminToken string
|
|
ActionsServiceAdminTokenExpiresAt time.Time
|
|
ActionsServiceURL string
|
|
|
|
retryMax int
|
|
retryWaitMax time.Duration
|
|
|
|
creds *ActionsAuth
|
|
config *GitHubConfig
|
|
logger logr.Logger
|
|
userAgent UserAgentInfo
|
|
|
|
rootCAs *x509.CertPool
|
|
tlsInsecureSkipVerify bool
|
|
|
|
proxyFunc ProxyFunc
|
|
}
|
|
|
|
type ProxyFunc func(req *http.Request) (*url.URL, error)
|
|
|
|
type ClientOption func(*Client)
|
|
|
|
type UserAgentInfo struct {
|
|
// Version is the version of the controller
|
|
Version string
|
|
// CommitSHA is the git commit SHA of the controller
|
|
CommitSHA string
|
|
// ScaleSetID is the ID of the scale set
|
|
ScaleSetID int
|
|
// HasProxy is true if the controller is running behind a proxy
|
|
HasProxy bool
|
|
// Subsystem is the subsystem such as listener, controller, etc.
|
|
// Each system may pick its own subsystem name.
|
|
Subsystem string
|
|
}
|
|
|
|
func (u UserAgentInfo) String() string {
|
|
scaleSetID := "NA"
|
|
if u.ScaleSetID > 0 {
|
|
scaleSetID = strconv.Itoa(u.ScaleSetID)
|
|
}
|
|
|
|
proxy := "Proxy/disabled"
|
|
if u.HasProxy {
|
|
proxy = "Proxy/enabled"
|
|
}
|
|
|
|
return fmt.Sprintf("actions-runner-controller/%s (%s; %s) ScaleSetID/%s (%s)", u.Version, u.CommitSHA, u.Subsystem, scaleSetID, proxy)
|
|
}
|
|
|
|
func WithLogger(logger logr.Logger) ClientOption {
|
|
return func(c *Client) {
|
|
c.logger = logger
|
|
}
|
|
}
|
|
|
|
func WithRetryMax(retryMax int) ClientOption {
|
|
return func(c *Client) {
|
|
c.retryMax = retryMax
|
|
}
|
|
}
|
|
|
|
// WithRetryWaitMax caps the wait between retries of the retrying HTTP
// client. NewClient defaults to 30 seconds.
func WithRetryWaitMax(retryWaitMax time.Duration) ClientOption {
	return ClientOption(func(client *Client) {
		client.retryWaitMax = retryWaitMax
	})
}
|
|
|
|
// WithRootCAs makes the transport verify server certificates against
// rootCAs. The pool is also folded into Identifier.
func WithRootCAs(rootCAs *x509.CertPool) ClientOption {
	return ClientOption(func(client *Client) {
		client.rootCAs = rootCAs
	})
}
|
|
|
|
func WithoutTLSVerify() ClientOption {
|
|
return func(c *Client) {
|
|
c.tlsInsecureSkipVerify = true
|
|
}
|
|
}
|
|
|
|
func WithProxy(proxyFunc ProxyFunc) ClientOption {
|
|
return func(c *Client) {
|
|
c.proxyFunc = proxyFunc
|
|
}
|
|
}
|
|
|
|
// NewClient builds a Client for githubConfigURL (an enterprise, organization
// or repository URL) that authenticates with creds.
//
// Options are applied before the HTTP stack is assembled, so they can
// influence logging, the retry policy, TLS and proxy settings. Requests go
// through a retryablehttp client exposed as a standard *http.Client with a
// 5 minute overall timeout.
func NewClient(githubConfigURL string, creds *ActionsAuth, options ...ClientOption) (*Client, error) {
	config, err := ParseGitHubConfigFromURL(githubConfigURL)
	if err != nil {
		return nil, fmt.Errorf("failed to parse githubConfigURL: %w", err)
	}

	ac := &Client{
		creds:  creds,
		config: config,
		logger: logr.Discard(),

		// Same values retryablehttp uses by default.
		retryMax:     4,
		retryWaitMax: 30 * time.Second,

		userAgent: UserAgentInfo{
			Version:    build.Version,
			CommitSHA:  build.CommitSHA,
			ScaleSetID: 0,
		},
	}
	for _, apply := range options {
		apply(ac)
	}

	rc := retryablehttp.NewClient()
	rc.Logger = &clientLogger{Logger: ac.logger}
	rc.RetryMax, rc.RetryWaitMax = ac.retryMax, ac.retryWaitMax
	// The overall timeout must exceed one minute to accommodate long polling.
	rc.HTTPClient.Timeout = 5 * time.Minute

	transport, ok := rc.HTTPClient.Transport.(*http.Transport)
	if !ok {
		// Not expected: retryablehttp.NewClient() installs
		// cleanhttp.DefaultPooledTransport(), which is an *http.Transport.
		return nil, fmt.Errorf("failed to get http transport from retryablehttp client")
	}

	tlsConfig := transport.TLSClientConfig
	if tlsConfig == nil {
		tlsConfig = &tls.Config{}
		transport.TLSClientConfig = tlsConfig
	}
	if ac.rootCAs != nil {
		tlsConfig.RootCAs = ac.rootCAs
	}
	if ac.tlsInsecureSkipVerify {
		tlsConfig.InsecureSkipVerify = true
	}

	// Assigned unconditionally: without WithProxy this is nil, which means the
	// transport uses no proxy.
	transport.Proxy = ac.proxyFunc

	ac.Client = rc.StandardClient()
	return ac, nil
}
|
|
|
|
func (c *Client) SetUserAgent(info UserAgentInfo) {
|
|
c.userAgent = info
|
|
}
|
|
|
|
// Identifier returns a string that identifies the client's configuration.
// It is used for caching client instances and for deciding whether a config
// change warrants creating a new client; any Client setting that requires a
// new client must be reflected here.
//
// The inputs, including secrets such as the token and private key, are only
// ever hashed. The result is a UUID derived from a SHA-256 hash of them.
func (c *Client) Identifier() string {
	identifier := fmt.Sprintf("configURL:%q,", c.config.ConfigURL.String())

	if c.creds != nil {
		if c.creds.Token != "" {
			identifier += fmt.Sprintf("token:%q,", c.creds.Token)
		}

		if c.creds.AppCreds != nil {
			// The IDs are formatted with %v, not %q. AppID is an int64 (it is
			// formatted with strconv.FormatInt when signing the JWT), and %q
			// renders an integer as a quoted rune. Every value above
			// utf8.MaxRune then becomes '\ufffd', so distinct IDs would
			// produce the same identifier.
			identifier += fmt.Sprintf(
				"appID:%v,installationID:%v,key:%q,",
				c.creds.AppCreds.AppID,
				c.creds.AppCreds.AppInstallationID,
				c.creds.AppCreds.AppPrivateKey,
			)
		}
	}

	if c.rootCAs != nil {
		// ignoring because this cert pool is intended not to come from SystemCertPool
		// nolint:staticcheck
		identifier += fmt.Sprintf("rootCAs:%q", c.rootCAs.Subjects())
	}

	return uuid.NewHash(sha256.New(), uuid.NameSpaceOID, []byte(identifier), 6).String()
}
|
|
|
|
// Do sends req with the embedded HTTP client and returns the response with
// its body fully buffered in memory.
//
// The original body is read to EOF and always closed, so the connection can
// be reused and is not leaked on a read failure. Any leading UTF-8 byte order
// mark is then stripped, and Body is replaced with an in-memory reader, so
// callers may decode it and closing it is a no-op.
//
// A read error takes precedence over a close error.
func (c *Client) Do(req *http.Request) (*http.Response, error) {
	resp, err := c.Client.Do(req)
	if err != nil {
		return nil, err
	}

	body, readErr := io.ReadAll(resp.Body)
	// Close unconditionally; previously a failed read returned early and
	// leaked the body.
	closeErr := resp.Body.Close()
	if readErr != nil {
		return nil, readErr
	}
	if closeErr != nil {
		return nil, closeErr
	}

	resp.Body = io.NopCloser(bytes.NewReader(trimByteOrderMark(body)))
	return resp, nil
}
|
|
|
|
// NewGitHubAPIRequest builds a request for path on the GitHub REST API of
// the configured GitHub instance and sets the client's User-Agent. Callers
// add their own Authorization and Content-Type headers.
func (c *Client) NewGitHubAPIRequest(ctx context.Context, method, path string, body io.Reader) (*http.Request, error) {
	apiURL := c.config.GitHubAPIURL(path)

	req, err := http.NewRequestWithContext(ctx, method, apiURL.String(), body)
	if err != nil {
		return nil, err
	}
	req.Header.Set("User-Agent", c.userAgent.String())

	return req, nil
}
|
|
|
|
// NewActionsServiceRequest builds an authenticated request for path on the
// Actions service. It performs these steps:
//   - refreshes the admin token if needed;
//   - joins the path component of path onto ActionsServiceURL;
//   - merges in any query parameters carried by path;
//   - defaults api-version to 6.0-preview when path does not set it;
//   - sets the JSON Content-Type, the bearer admin token and User-Agent.
func (c *Client) NewActionsServiceRequest(ctx context.Context, method, path string, body io.Reader) (*http.Request, error) {
	if err := c.updateTokenIfNeeded(ctx); err != nil {
		return nil, err
	}

	// updateTokenIfNeeded writes these fields while holding c.mu. Snapshot
	// them under the same lock so a concurrent refresh cannot race with the
	// reads below.
	c.mu.Lock()
	serviceURL := c.ActionsServiceURL
	adminToken := c.ActionsServiceAdminToken
	c.mu.Unlock()

	parsedPath, err := url.Parse(path)
	if err != nil {
		return nil, err
	}

	urlString, err := url.JoinPath(serviceURL, parsedPath.Path)
	if err != nil {
		return nil, err
	}

	u, err := url.Parse(urlString)
	if err != nil {
		return nil, err
	}

	// Caller-supplied parameters override those already on the service URL.
	q := u.Query()
	for k, v := range parsedPath.Query() {
		q[k] = v
	}
	if q.Get("api-version") == "" {
		q.Set("api-version", "6.0-preview")
	}
	u.RawQuery = q.Encode()

	req, err := http.NewRequestWithContext(ctx, method, u.String(), body)
	if err != nil {
		return nil, err
	}

	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", adminToken))
	req.Header.Set("User-Agent", c.userAgent.String())

	return req, nil
}
|
|
|
|
// GetRunnerScaleSet looks up the runner scale set named runnerScaleSetName
// in the given runner group.
//
// It returns (nil, nil) when no scale set matches, and an error when more
// than one does. A non-200 response is converted with
// ParseActionsErrorFromResponse.
func (c *Client) GetRunnerScaleSet(ctx context.Context, runnerGroupId int, runnerScaleSetName string) (*RunnerScaleSet, error) {
	// Escape the name so reserved query characters ('&', '+', '#', '%', ...)
	// are sent as part of the value instead of altering the query string.
	path := fmt.Sprintf("/%s?runnerGroupId=%d&name=%s", scaleSetEndpoint, runnerGroupId, url.QueryEscape(runnerScaleSetName))
	req, err := c.NewActionsServiceRequest(ctx, http.MethodGet, path, nil)
	if err != nil {
		return nil, err
	}

	resp, err := c.Do(req)
	if err != nil {
		return nil, err
	}

	if resp.StatusCode != http.StatusOK {
		return nil, ParseActionsErrorFromResponse(resp)
	}

	// Decode into a value so a literal `null` body yields an empty result
	// instead of a nil-pointer dereference.
	var runnerScaleSetList runnerScaleSetsResponse
	if err := json.NewDecoder(resp.Body).Decode(&runnerScaleSetList); err != nil {
		return nil, err
	}

	switch {
	case runnerScaleSetList.Count == 0:
		return nil, nil
	case runnerScaleSetList.Count > 1:
		return nil, fmt.Errorf("multiple runner scale sets found with name %s", runnerScaleSetName)
	case len(runnerScaleSetList.RunnerScaleSets) == 0:
		return nil, fmt.Errorf("runner scale set list for name %s reports count %d but contains no entries", runnerScaleSetName, runnerScaleSetList.Count)
	}

	return &runnerScaleSetList.RunnerScaleSets[0], nil
}
|
|
|
|
// GetRunnerScaleSetById fetches one runner scale set by ID. A non-200
// response is converted with ParseActionsErrorFromResponse.
func (c *Client) GetRunnerScaleSetById(ctx context.Context, runnerScaleSetId int) (*RunnerScaleSet, error) {
	endpoint := fmt.Sprintf("/%s/%d", scaleSetEndpoint, runnerScaleSetId)
	req, err := c.NewActionsServiceRequest(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, err
	}

	resp, err := c.Do(req)
	switch {
	case err != nil:
		return nil, err
	case resp.StatusCode != http.StatusOK:
		return nil, ParseActionsErrorFromResponse(resp)
	}

	var scaleSet *RunnerScaleSet
	if err := json.NewDecoder(resp.Body).Decode(&scaleSet); err != nil {
		return nil, err
	}
	return scaleSet, nil
}
|
|
|
|
// GetRunnerGroupByName returns the runner group named runnerGroup.
//
// Unlike the scale set lookups, a missing group is an error, as is more than
// one match. A non-200 response is reported with its status code and body.
func (c *Client) GetRunnerGroupByName(ctx context.Context, runnerGroup string) (*RunnerGroup, error) {
	// Runner group names are free-form; escape them so characters such as
	// '&', '+', '#' or '%' are not interpreted as query syntax.
	path := fmt.Sprintf("/_apis/runtime/runnergroups/?groupName=%s", url.QueryEscape(runnerGroup))
	req, err := c.NewActionsServiceRequest(ctx, http.MethodGet, path, nil)
	if err != nil {
		return nil, err
	}

	resp, err := c.Do(req)
	if err != nil {
		return nil, err
	}

	if resp.StatusCode != http.StatusOK {
		body, err := io.ReadAll(resp.Body)
		if err != nil {
			return nil, err
		}
		return nil, fmt.Errorf("unexpected status code: %d - body: %s", resp.StatusCode, string(body))
	}

	// Decode into a value so a `null` body cannot cause a nil dereference.
	var runnerGroupList RunnerGroupList
	if err := json.NewDecoder(resp.Body).Decode(&runnerGroupList); err != nil {
		return nil, err
	}

	if runnerGroupList.Count == 0 {
		return nil, fmt.Errorf("no runner group found with name '%s'", runnerGroup)
	}

	if runnerGroupList.Count > 1 {
		return nil, fmt.Errorf("multiple runner group found with name %s", runnerGroup)
	}

	if len(runnerGroupList.RunnerGroups) == 0 {
		return nil, fmt.Errorf("runner group list for name '%s' reports count %d but contains no entries", runnerGroup, runnerGroupList.Count)
	}

	return &runnerGroupList.RunnerGroups[0], nil
}
|
|
|
|
// CreateRunnerScaleSet POSTs runnerScaleSet to the Actions service and
// returns the scale set from the response body. A non-200 response is
// converted with ParseActionsErrorFromResponse.
func (c *Client) CreateRunnerScaleSet(ctx context.Context, runnerScaleSet *RunnerScaleSet) (*RunnerScaleSet, error) {
	payload, err := json.Marshal(runnerScaleSet)
	if err != nil {
		return nil, err
	}

	req, err := c.NewActionsServiceRequest(ctx, http.MethodPost, scaleSetEndpoint, bytes.NewReader(payload))
	if err != nil {
		return nil, err
	}

	resp, err := c.Do(req)
	switch {
	case err != nil:
		return nil, err
	case resp.StatusCode != http.StatusOK:
		return nil, ParseActionsErrorFromResponse(resp)
	}

	var created *RunnerScaleSet
	if err := json.NewDecoder(resp.Body).Decode(&created); err != nil {
		return nil, err
	}
	return created, nil
}
|
|
|
|
// UpdateRunnerScaleSet PATCHes the scale set identified by runnerScaleSetId
// with runnerScaleSet and returns the scale set from the response body. A
// non-200 response is converted with ParseActionsErrorFromResponse.
func (c *Client) UpdateRunnerScaleSet(ctx context.Context, runnerScaleSetId int, runnerScaleSet *RunnerScaleSet) (*RunnerScaleSet, error) {
	endpoint := fmt.Sprintf("%s/%d", scaleSetEndpoint, runnerScaleSetId)

	payload, err := json.Marshal(runnerScaleSet)
	if err != nil {
		return nil, err
	}

	req, err := c.NewActionsServiceRequest(ctx, http.MethodPatch, endpoint, bytes.NewReader(payload))
	if err != nil {
		return nil, err
	}

	resp, err := c.Do(req)
	switch {
	case err != nil:
		return nil, err
	case resp.StatusCode != http.StatusOK:
		return nil, ParseActionsErrorFromResponse(resp)
	}

	var updated *RunnerScaleSet
	if err := json.NewDecoder(resp.Body).Decode(&updated); err != nil {
		return nil, err
	}
	return updated, nil
}
|
|
|
|
// DeleteRunnerScaleSet deletes the scale set identified by runnerScaleSetId.
// Success is signalled by 204 No Content; any other status is converted
// with ParseActionsErrorFromResponse.
func (c *Client) DeleteRunnerScaleSet(ctx context.Context, runnerScaleSetId int) error {
	endpoint := fmt.Sprintf("/%s/%d", scaleSetEndpoint, runnerScaleSetId)
	req, err := c.NewActionsServiceRequest(ctx, http.MethodDelete, endpoint, nil)
	if err != nil {
		return err
	}

	resp, err := c.Do(req)
	if err != nil {
		return err
	}
	if resp.StatusCode != http.StatusNoContent {
		return ParseActionsErrorFromResponse(resp)
	}

	// Do has already buffered the body, so closing cannot fail meaningfully.
	_ = resp.Body.Close()
	return nil
}
|
|
|
|
// GetMessage fetches the next message from the message queue at
// messageQueueUrl, authenticating with messageQueueAccessToken. When
// lastMessageId > 0 it is sent as the lastMessageId query parameter.
//
// It returns (nil, nil) on 202 Accepted (no message). A 401 is reported as
// *MessageQueueTokenExpiredError. Other non-200 statuses are converted with
// ParseActionsErrorFromResponse.
func (c *Client) GetMessage(ctx context.Context, messageQueueUrl, messageQueueAccessToken string, lastMessageId int64) (*RunnerScaleSetMessage, error) {
	target, err := url.Parse(messageQueueUrl)
	if err != nil {
		return nil, err
	}
	if lastMessageId > 0 {
		query := target.Query()
		query.Set("lastMessageId", strconv.FormatInt(lastMessageId, 10))
		target.RawQuery = query.Encode()
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, target.String(), nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Accept", "application/json; api-version=6.0-preview")
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", messageQueueAccessToken))
	req.Header.Set("User-Agent", c.userAgent.String())

	resp, err := c.Do(req)
	if err != nil {
		return nil, err
	}

	switch resp.StatusCode {
	case http.StatusOK:
		var message *RunnerScaleSetMessage
		if err := json.NewDecoder(resp.Body).Decode(&message); err != nil {
			return nil, err
		}
		return message, nil
	case http.StatusAccepted:
		defer resp.Body.Close()
		return nil, nil
	case http.StatusUnauthorized:
		defer resp.Body.Close()
		raw, err := io.ReadAll(resp.Body)
		if err != nil {
			return nil, err
		}
		return nil, &MessageQueueTokenExpiredError{msg: string(trimByteOrderMark(raw))}
	default:
		return nil, ParseActionsErrorFromResponse(resp)
	}
}
|
|
|
|
// DeleteMessage removes message messageId from the queue at messageQueueUrl
// by appending "/<messageId>" to its path.
//
// Success is 204 No Content. A 401 is reported as
// *MessageQueueTokenExpiredError. Other statuses are converted with
// ParseActionsErrorFromResponse.
func (c *Client) DeleteMessage(ctx context.Context, messageQueueUrl, messageQueueAccessToken string, messageId int64) error {
	target, err := url.Parse(messageQueueUrl)
	if err != nil {
		return err
	}
	target.Path = fmt.Sprintf("%s/%d", target.Path, messageId)

	req, err := http.NewRequestWithContext(ctx, http.MethodDelete, target.String(), nil)
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", messageQueueAccessToken))
	req.Header.Set("User-Agent", c.userAgent.String())

	resp, err := c.Do(req)
	if err != nil {
		return err
	}

	switch resp.StatusCode {
	case http.StatusNoContent:
		return nil
	case http.StatusUnauthorized:
		defer resp.Body.Close()
		raw, err := io.ReadAll(resp.Body)
		if err != nil {
			return err
		}
		return &MessageQueueTokenExpiredError{msg: string(trimByteOrderMark(raw))}
	default:
		return ParseActionsErrorFromResponse(resp)
	}
}
|
|
|
|
// CreateMessageSession opens a message session for the scale set on behalf
// of owner and returns the session decoded from the 200 OK response. The
// session value is returned even when err is non-nil.
func (c *Client) CreateMessageSession(ctx context.Context, runnerScaleSetId int, owner string) (*RunnerScaleSetSession, error) {
	endpoint := fmt.Sprintf("/%s/%d/sessions", scaleSetEndpoint, runnerScaleSetId)

	requestData, err := json.Marshal(&RunnerScaleSetSession{OwnerName: owner})
	if err != nil {
		return nil, err
	}

	session := new(RunnerScaleSetSession)
	err = c.doSessionRequest(ctx, http.MethodPost, endpoint, bytes.NewBuffer(requestData), http.StatusOK, session)
	return session, err
}
|
|
|
|
// DeleteMessageSession closes the message session sessionId of the scale
// set. The request expects 204 No Content and decodes no response body.
func (c *Client) DeleteMessageSession(ctx context.Context, runnerScaleSetId int, sessionId *uuid.UUID) error {
	endpoint := fmt.Sprintf("/%s/%d/sessions/%s", scaleSetEndpoint, runnerScaleSetId, sessionId.String())
	err := c.doSessionRequest(ctx, http.MethodDelete, endpoint, nil, http.StatusNoContent, nil)
	return err
}
|
|
|
|
// RefreshMessageSession PATCHes the message session sessionId of the scale
// set and returns the refreshed session decoded from the 200 OK response.
// The session value is returned even when err is non-nil.
func (c *Client) RefreshMessageSession(ctx context.Context, runnerScaleSetId int, sessionId *uuid.UUID) (*RunnerScaleSetSession, error) {
	endpoint := fmt.Sprintf("/%s/%d/sessions/%s", scaleSetEndpoint, runnerScaleSetId, sessionId.String())
	session := new(RunnerScaleSetSession)
	if err := c.doSessionRequest(ctx, http.MethodPatch, endpoint, nil, http.StatusOK, session); err != nil {
		return session, err
	}
	return session, nil
}
|
|
|
|
// doSessionRequest sends an Actions service request for the message-session
// endpoints and interprets the response:
//   - status == expectedResponseStatusCode: success. The body is decoded
//     into responseUnmarshalTarget when it is non-nil.
//   - any 4xx status: converted with ParseActionsErrorFromResponse.
//   - any other status: an error carrying the status code and body.
func (c *Client) doSessionRequest(ctx context.Context, method, path string, requestData io.Reader, expectedResponseStatusCode int, responseUnmarshalTarget any) error {
	req, err := c.NewActionsServiceRequest(ctx, method, path, requestData)
	if err != nil {
		return err
	}

	resp, err := c.Do(req)
	if err != nil {
		return err
	}

	if resp.StatusCode == expectedResponseStatusCode {
		// A nil target means the caller only checks the status (e.g. a DELETE
		// answered with 204 No Content). Previously this case fell through to
		// the "unexpected status code" error below.
		if responseUnmarshalTarget == nil {
			return nil
		}
		return json.NewDecoder(resp.Body).Decode(responseUnmarshalTarget)
	}

	if resp.StatusCode >= 400 && resp.StatusCode < 500 {
		return ParseActionsErrorFromResponse(resp)
	}

	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return err
	}
	body = trimByteOrderMark(body)

	return fmt.Errorf("unexpected status code: %d - body: %s", resp.StatusCode, string(body))
}
|
|
|
|
// AcquireJobs asks the Actions service to assign the job requests in
// requestIds to the scale set and returns the IDs the service acquired. It
// authenticates with the message queue access token, not the admin token.
// A 401 is reported as *MessageQueueTokenExpiredError; other non-200
// statuses are converted with ParseActionsErrorFromResponse.
func (c *Client) AcquireJobs(ctx context.Context, runnerScaleSetId int, messageQueueAccessToken string, requestIds []int64) ([]int64, error) {
	// ActionsServiceURL is rewritten by updateTokenIfNeeded while it holds
	// c.mu, so read it under the same lock.
	c.mu.Lock()
	serviceURL := c.ActionsServiceURL
	c.mu.Unlock()

	base, err := url.Parse(serviceURL)
	if err != nil {
		return nil, err
	}
	// JoinPath avoids a doubled '/' when the service URL ends with a slash.
	endpoint := base.JoinPath(scaleSetEndpoint, strconv.Itoa(runnerScaleSetId), "acquirejobs")
	query := endpoint.Query()
	query.Set("api-version", "6.0-preview")
	endpoint.RawQuery = query.Encode()

	body, err := json.Marshal(requestIds)
	if err != nil {
		return nil, err
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint.String(), bytes.NewBuffer(body))
	if err != nil {
		return nil, err
	}

	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", messageQueueAccessToken))
	req.Header.Set("User-Agent", c.userAgent.String())

	resp, err := c.Do(req)
	if err != nil {
		return nil, err
	}

	if resp.StatusCode != http.StatusOK {
		if resp.StatusCode != http.StatusUnauthorized {
			return nil, ParseActionsErrorFromResponse(resp)
		}

		defer resp.Body.Close()
		raw, err := io.ReadAll(resp.Body)
		if err != nil {
			return nil, err
		}
		return nil, &MessageQueueTokenExpiredError{msg: string(trimByteOrderMark(raw))}
	}

	// Decode into a value so a `null` body yields no IDs instead of a
	// nil-pointer dereference.
	var acquiredJobs Int64List
	if err := json.NewDecoder(resp.Body).Decode(&acquiredJobs); err != nil {
		return nil, err
	}

	return acquiredJobs.Value, nil
}
|
|
|
|
// GetAcquirableJobs lists the jobs the scale set can currently acquire.
// A 204 No Content response becomes an empty, non-nil list. Other non-200
// statuses are converted with ParseActionsErrorFromResponse.
func (c *Client) GetAcquirableJobs(ctx context.Context, runnerScaleSetId int) (*AcquirableJobList, error) {
	endpoint := fmt.Sprintf("/%s/%d/acquirablejobs", scaleSetEndpoint, runnerScaleSetId)
	req, err := c.NewActionsServiceRequest(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, err
	}

	resp, err := c.Do(req)
	if err != nil {
		return nil, err
	}

	switch resp.StatusCode {
	case http.StatusOK:
		var jobs *AcquirableJobList
		if err := json.NewDecoder(resp.Body).Decode(&jobs); err != nil {
			return nil, err
		}
		return jobs, nil
	case http.StatusNoContent:
		defer resp.Body.Close()
		return &AcquirableJobList{Count: 0, Jobs: []AcquirableJob{}}, nil
	default:
		return nil, ParseActionsErrorFromResponse(resp)
	}
}
|
|
|
|
// GenerateJitRunnerConfig requests a just-in-time runner configuration for
// the scale set using jitRunnerSetting. A non-200 response is converted with
// ParseActionsErrorFromResponse.
func (c *Client) GenerateJitRunnerConfig(ctx context.Context, jitRunnerSetting *RunnerScaleSetJitRunnerSetting, scaleSetId int) (*RunnerScaleSetJitRunnerConfig, error) {
	endpoint := fmt.Sprintf("/%s/%d/generatejitconfig", scaleSetEndpoint, scaleSetId)

	payload, err := json.Marshal(jitRunnerSetting)
	if err != nil {
		return nil, err
	}

	req, err := c.NewActionsServiceRequest(ctx, http.MethodPost, endpoint, bytes.NewBuffer(payload))
	if err != nil {
		return nil, err
	}

	resp, err := c.Do(req)
	switch {
	case err != nil:
		return nil, err
	case resp.StatusCode != http.StatusOK:
		return nil, ParseActionsErrorFromResponse(resp)
	}

	var jitConfig *RunnerScaleSetJitRunnerConfig
	if err := json.NewDecoder(resp.Body).Decode(&jitConfig); err != nil {
		return nil, err
	}
	return jitConfig, nil
}
|
|
|
|
// GetRunner fetches the runner with ID runnerId. A non-200 response is
// converted with ParseActionsErrorFromResponse.
func (c *Client) GetRunner(ctx context.Context, runnerId int64) (*RunnerReference, error) {
	endpoint := fmt.Sprintf("/%s/%d", runnerEndpoint, runnerId)
	req, err := c.NewActionsServiceRequest(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, err
	}

	resp, err := c.Do(req)
	switch {
	case err != nil:
		return nil, err
	case resp.StatusCode != http.StatusOK:
		return nil, ParseActionsErrorFromResponse(resp)
	}

	var runner *RunnerReference
	if err := json.NewDecoder(resp.Body).Decode(&runner); err != nil {
		return nil, err
	}
	return runner, nil
}
|
|
|
|
// GetRunnerByName looks up a runner by name. It returns (nil, nil) when no
// runner matches and an error when several do. A non-200 response is
// converted with ParseActionsErrorFromResponse.
func (c *Client) GetRunnerByName(ctx context.Context, runnerName string) (*RunnerReference, error) {
	// Escape the name so reserved query characters ('&', '+', '#', '%', ...)
	// travel as part of the value.
	path := fmt.Sprintf("/%s?agentName=%s", runnerEndpoint, url.QueryEscape(runnerName))

	req, err := c.NewActionsServiceRequest(ctx, http.MethodGet, path, nil)
	if err != nil {
		return nil, err
	}

	resp, err := c.Do(req)
	if err != nil {
		return nil, err
	}

	if resp.StatusCode != http.StatusOK {
		return nil, ParseActionsErrorFromResponse(resp)
	}

	// Decode into a value so a `null` body cannot cause a nil dereference.
	var runnerList RunnerReferenceList
	if err := json.NewDecoder(resp.Body).Decode(&runnerList); err != nil {
		return nil, err
	}

	if runnerList.Count == 0 {
		return nil, nil
	}

	if runnerList.Count > 1 {
		return nil, fmt.Errorf("multiple runner found with name %s", runnerName)
	}

	if len(runnerList.RunnerReferences) == 0 {
		return nil, fmt.Errorf("runner list for name %s reports count %d but contains no entries", runnerName, runnerList.Count)
	}

	return &runnerList.RunnerReferences[0], nil
}
|
|
|
|
// RemoveRunner deletes the runner with ID runnerId. Success is 204 No
// Content; any other status is converted with ParseActionsErrorFromResponse.
func (c *Client) RemoveRunner(ctx context.Context, runnerId int64) error {
	endpoint := fmt.Sprintf("/%s/%d", runnerEndpoint, runnerId)
	req, err := c.NewActionsServiceRequest(ctx, http.MethodDelete, endpoint, nil)
	if err != nil {
		return err
	}

	resp, err := c.Do(req)
	if err != nil {
		return err
	}
	if resp.StatusCode != http.StatusNoContent {
		return ParseActionsErrorFromResponse(resp)
	}

	// Do has already buffered the body, so closing cannot fail meaningfully.
	_ = resp.Body.Close()
	return nil
}
|
|
|
|
type registrationToken struct {
|
|
Token *string `json:"token,omitempty"`
|
|
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
|
}
|
|
|
|
// getRunnerRegistrationToken POSTs to the registration-token endpoint for
// the configured scope (organization, enterprise or repository). It
// authenticates with the configured token when one is set, and otherwise
// with a GitHub App installation access token. Any status other than
// 201 Created is an error carrying the response body.
func (c *Client) getRunnerRegistrationToken(ctx context.Context) (*registrationToken, error) {
	path, err := createRegistrationTokenPath(c.config)
	if err != nil {
		return nil, err
	}

	var buf bytes.Buffer
	req, err := c.NewGitHubAPIRequest(ctx, http.MethodPost, path, &buf)
	if err != nil {
		return nil, err
	}

	var bearerToken string
	switch {
	case c.creds == nil:
		return nil, errors.New("no credentials configured for runner registration token request")
	case c.creds.Token != "":
		bearerToken = fmt.Sprintf("Bearer %v", c.creds.Token)
	case c.creds.AppCreds != nil:
		appToken, err := c.fetchAccessToken(ctx, c.config.ConfigURL.String(), c.creds.AppCreds)
		if err != nil {
			return nil, err
		}
		bearerToken = fmt.Sprintf("Bearer %v", appToken.Token)
	default:
		// Neither a token nor GitHub App credentials: previously this
		// dereferenced nil AppCreds.
		return nil, errors.New("no token or GitHub App credentials configured for runner registration token request")
	}

	req.Header.Set("Content-Type", "application/vnd.github.v3+json")
	req.Header.Set("Authorization", bearerToken)

	c.logger.Info("getting runner registration token", "registrationTokenURL", req.URL.String())

	resp, err := c.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusCreated {
		body, err := io.ReadAll(resp.Body)
		if err != nil {
			return nil, err
		}
		return nil, fmt.Errorf("unexpected response from Actions service during registration token call: %v - %v", resp.StatusCode, string(body))
	}

	// Named token (not registrationToken) to avoid shadowing the type name.
	var token *registrationToken
	if err := json.NewDecoder(resp.Body).Decode(&token); err != nil {
		return nil, err
	}

	return token, nil
}
|
|
|
|
// Format: https://docs.github.com/en/rest/apps/apps#create-an-installation-access-token-for-an-app
|
|
type accessToken struct {
|
|
Token string `json:"token"`
|
|
ExpiresAt time.Time `json:"expires_at"`
|
|
}
|
|
|
|
// fetchAccessToken exchanges a JWT signed with the GitHub App's private key
// for an installation access token for creds.AppInstallationID.
// See https://docs.github.com/en/rest/apps/apps#create-an-installation-access-token-for-an-app
//
// gitHubConfigURL is currently unused. A non-2xx response, or a response
// without a token, is returned as an error. On error the returned token is
// always nil.
func (c *Client) fetchAccessToken(ctx context.Context, gitHubConfigURL string, creds *GitHubAppAuth) (*accessToken, error) {
	accessTokenJWT, err := createJWTForGitHubApp(creds)
	if err != nil {
		return nil, err
	}

	path := fmt.Sprintf("/app/installations/%v/access_tokens", creds.AppInstallationID)
	req, err := c.NewGitHubAPIRequest(ctx, http.MethodPost, path, nil)
	if err != nil {
		return nil, err
	}

	req.Header.Set("Content-Type", "application/vnd.github+json")
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessTokenJWT))

	c.logger.Info("getting access token for GitHub App auth", "accessTokenURL", req.URL.String())

	resp, err := c.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	// Error responses (bad JWT, unknown installation, ...) are JSON too.
	// Without this check they would decode into a token with an empty value.
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		body, err := io.ReadAll(resp.Body)
		if err != nil {
			return nil, fmt.Errorf("unexpected response from GitHub API while fetching access token: %v - %w", resp.StatusCode, err)
		}
		return nil, fmt.Errorf("unexpected response from GitHub API while fetching access token: %v - %v", resp.StatusCode, string(body))
	}

	var token *accessToken
	if err := json.NewDecoder(resp.Body).Decode(&token); err != nil {
		return nil, err
	}
	if token == nil || token.Token == "" {
		return nil, errors.New("GitHub API returned an empty installation access token")
	}

	return token, nil
}
|
|
|
|
// ActionsServiceAdminConnection is the body returned by the GitHub API's
// /actions/runner-registration endpoint. It holds the Actions service base
// URL and an admin token (a JWT) used as the bearer token for that service.
type ActionsServiceAdminConnection struct {
	// ActionsServiceUrl is the Actions service base URL.
	ActionsServiceUrl *string `json:"url,omitempty"`

	// AdminToken is the bearer token for Actions service requests.
	AdminToken *string `json:"token,omitempty"`
}
|
|
|
|
// getActionsServiceAdminConnection exchanges a runner registration token for
// the Actions service URL and an admin token. It POSTs a "register" runner
// event for the configured URL to the GitHub API's
// /actions/runner-registration endpoint, authenticating as
// "RemoteAuth <token>". A non-2xx response is an error carrying the body.
func (c *Client) getActionsServiceAdminConnection(ctx context.Context, rt *registrationToken) (*ActionsServiceAdminConnection, error) {
	// Token is optional in the registration-token response; check it before
	// it is dereferenced for the Authorization header.
	if rt == nil || rt.Token == nil {
		return nil, errors.New("runner registration token is missing")
	}

	payload := struct {
		Url         string `json:"url"`
		RunnerEvent string `json:"runner_event"`
	}{
		Url:         c.config.ConfigURL.String(),
		RunnerEvent: "register",
	}

	buf := &bytes.Buffer{}
	enc := json.NewEncoder(buf)
	// Keep characters such as '&' in the config URL literal instead of \u0026.
	enc.SetEscapeHTML(false)
	if err := enc.Encode(payload); err != nil {
		return nil, err
	}

	req, err := c.NewGitHubAPIRequest(ctx, http.MethodPost, "/actions/runner-registration", buf)
	if err != nil {
		return nil, err
	}

	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", fmt.Sprintf("RemoteAuth %s", *rt.Token))

	c.logger.Info("getting Actions tenant URL and JWT", "registrationURL", req.URL.String())

	resp, err := c.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		registrationErr := fmt.Errorf("unexpected response from Actions service during registration call: %v", resp.StatusCode)

		body, err := io.ReadAll(resp.Body)
		if err != nil {
			return nil, fmt.Errorf("%v - %w", registrationErr, err)
		}
		return nil, fmt.Errorf("%v - %v", registrationErr, string(body))
	}

	var conn *ActionsServiceAdminConnection
	if err := json.NewDecoder(resp.Body).Decode(&conn); err != nil {
		return nil, err
	}

	return conn, nil
}
|
|
|
|
func createRegistrationTokenPath(config *GitHubConfig) (string, error) {
|
|
switch config.Scope {
|
|
case GitHubScopeOrganization:
|
|
path := fmt.Sprintf("/orgs/%s/actions/runners/registration-token", config.Organization)
|
|
return path, nil
|
|
|
|
case GitHubScopeEnterprise:
|
|
path := fmt.Sprintf("/enterprises/%s/actions/runners/registration-token", config.Enterprise)
|
|
return path, nil
|
|
|
|
case GitHubScopeRepository:
|
|
path := fmt.Sprintf("/repos/%s/%s/actions/runners/registration-token", config.Organization, config.Repository)
|
|
return path, nil
|
|
|
|
default:
|
|
return "", fmt.Errorf("unknown scope for config url: %s", config.ConfigURL)
|
|
}
|
|
}
|
|
|
|
// createJWTForGitHubApp signs an RS256 JWT, with the app ID as issuer,
// using the GitHub App's PEM-encoded RSA private key. The JWT authenticates
// the app when it requests an installation access token.
// See https://docs.github.com/en/developers/apps/building-github-apps/authenticating-with-github-apps#authenticating-as-a-github-app
func createJWTForGitHubApp(appAuth *GitHubAppAuth) (string, error) {
	signingKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(appAuth.AppPrivateKey))
	if err != nil {
		return "", err
	}

	// Backdate issuance by a minute to tolerate clock skew.
	issuedAt := time.Now().Add(-60 * time.Second)

	// GitHub allows at most 10 minutes; 9 minutes after the backdated issue
	// time stays well within that.
	claims := &jwt.RegisteredClaims{
		IssuedAt:  jwt.NewNumericDate(issuedAt),
		ExpiresAt: jwt.NewNumericDate(issuedAt.Add(9 * time.Minute)),
		Issuer:    strconv.FormatInt(appAuth.AppID, 10),
	}

	return jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(signingKey)
}
|
|
|
|
// trimByteOrderMark strips a leading UTF-8 byte order mark (EF BB BF) from
// body. If no BOM is present, body is returned unchanged. The result shares
// body's backing array.
func trimByteOrderMark(body []byte) []byte {
	utf8BOM := []byte{0xEF, 0xBB, 0xBF}
	if !bytes.HasPrefix(body, utf8BOM) {
		return body
	}
	return body[len(utf8BOM):]
}
|
|
|
|
// actionsServiceAdminTokenExpiresAt returns the expiry (exp claim) of the
// Actions service admin JWT. The signature is NOT verified; the claims are
// only read to schedule refreshes. A token that cannot be parsed, or that
// has no exp claim, is an error.
func actionsServiceAdminTokenExpiresAt(jwtToken string) (time.Time, error) {
	type JwtClaims struct {
		jwt.RegisteredClaims
	}

	token, _, err := jwt.NewParser().ParseUnverified(jwtToken, &JwtClaims{})
	if err != nil {
		return time.Time{}, fmt.Errorf("failed to parse jwt token: %w", err)
	}

	claims, ok := token.Claims.(*JwtClaims)
	if !ok {
		return time.Time{}, fmt.Errorf("failed to parse token claims to get expire at")
	}

	// exp is optional in a JWT. Without this check, a token lacking it would
	// cause a nil-pointer dereference on ExpiresAt.
	if claims.ExpiresAt == nil {
		return time.Time{}, fmt.Errorf("jwt token has no expiration (exp) claim")
	}

	return claims.ExpiresAt.Time, nil
}
|
|
|
|
// updateTokenIfNeeded refreshes ActionsServiceURL, ActionsServiceAdminToken
// and ActionsServiceAdminTokenExpiresAt when no expiry is known yet or the
// token expires within 60 seconds.
//
// The refresh obtains a runner registration token and exchanges it for an
// Actions service admin connection. c.mu is held for the whole call,
// including the network round trips. The three fields are only updated once
// every step has succeeded, so a failed refresh leaves them unchanged.
func (c *Client) updateTokenIfNeeded(ctx context.Context) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	aboutToExpire := time.Now().Add(60 * time.Second).After(c.ActionsServiceAdminTokenExpiresAt)
	if !aboutToExpire && !c.ActionsServiceAdminTokenExpiresAt.IsZero() {
		return nil
	}

	c.logger.Info("refreshing token", "githubConfigUrl", c.config.ConfigURL.String())
	rt, err := c.getRunnerRegistrationToken(ctx)
	if err != nil {
		return fmt.Errorf("failed to get runner registration token on refresh: %w", err)
	}

	adminConnInfo, err := c.getActionsServiceAdminConnection(ctx, rt)
	if err != nil {
		return fmt.Errorf("failed to get actions service admin connection on refresh: %w", err)
	}

	// Both fields are optional pointers in the response; dereferencing a
	// missing one would panic.
	if adminConnInfo == nil || adminConnInfo.ActionsServiceUrl == nil || adminConnInfo.AdminToken == nil {
		return fmt.Errorf("failed to get actions service admin connection on refresh: response is missing the service URL or admin token")
	}

	expiresAt, err := actionsServiceAdminTokenExpiresAt(*adminConnInfo.AdminToken)
	if err != nil {
		return fmt.Errorf("failed to get admin token expire at on refresh: %w", err)
	}

	// Commit all three together so a new token is never paired with a stale
	// URL or a zeroed expiry.
	c.ActionsServiceURL = *adminConnInfo.ActionsServiceUrl
	c.ActionsServiceAdminToken = *adminConnInfo.AdminToken
	c.ActionsServiceAdminTokenExpiresAt = expiresAt

	return nil
}
|