From b78cadd90105030e122f78758ba7c3c283a931ab Mon Sep 17 00:00:00 2001 From: Nikola Jokic Date: Fri, 8 Dec 2023 13:41:06 +0100 Subject: [PATCH] Refactoring listener app with configurable fallback (#3096) --- Dockerfile | 2 + cmd/ghalistener/app/app.go | 133 ++++ cmd/ghalistener/app/app_test.go | 85 +++ cmd/ghalistener/app/mocks/listener.go | 43 ++ cmd/ghalistener/app/mocks/worker.go | 58 ++ cmd/ghalistener/config/config.go | 147 +++++ cmd/ghalistener/config/config_test.go | 92 +++ cmd/ghalistener/listener/listener.go | 388 +++++++++++ cmd/ghalistener/listener/listener_test.go | 613 ++++++++++++++++++ cmd/ghalistener/listener/mocks/client.go | 176 +++++ cmd/ghalistener/listener/mocks/handler.go | 58 ++ cmd/ghalistener/main.go | 40 ++ cmd/ghalistener/metrics/metrics.go | 387 +++++++++++ cmd/ghalistener/metrics/mocks/publisher.go | 53 ++ .../metrics/mocks/server_publisher.go | 69 ++ cmd/ghalistener/worker/worker.go | 228 +++++++ .../mock_KubernetesManager.go | 2 +- .../mock_RunnerScaleSetClient.go | 2 +- .../actions.github.com/resourcebuilder.go | 15 +- github/actions/mock_ActionsService.go | 2 +- github/actions/mock_SessionService.go | 2 +- main.go | 2 + 22 files changed, 2590 insertions(+), 7 deletions(-) create mode 100644 cmd/ghalistener/app/app.go create mode 100644 cmd/ghalistener/app/app_test.go create mode 100644 cmd/ghalistener/app/mocks/listener.go create mode 100644 cmd/ghalistener/app/mocks/worker.go create mode 100644 cmd/ghalistener/config/config.go create mode 100644 cmd/ghalistener/config/config_test.go create mode 100644 cmd/ghalistener/listener/listener.go create mode 100644 cmd/ghalistener/listener/listener_test.go create mode 100644 cmd/ghalistener/listener/mocks/client.go create mode 100644 cmd/ghalistener/listener/mocks/handler.go create mode 100644 cmd/ghalistener/main.go create mode 100644 cmd/ghalistener/metrics/metrics.go create mode 100644 cmd/ghalistener/metrics/mocks/publisher.go create mode 100644 
cmd/ghalistener/metrics/mocks/server_publisher.go create mode 100644 cmd/ghalistener/worker/worker.go diff --git a/Dockerfile b/Dockerfile index fdeac8ef..e3c6e2f1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -38,6 +38,7 @@ RUN --mount=target=. \ export GOOS=${TARGETOS} GOARCH=${TARGETARCH} GOARM=${TARGETVARIANT#v} && \ go build -trimpath -ldflags="-s -w -X 'github.com/actions/actions-runner-controller/build.Version=${VERSION}' -X 'github.com/actions/actions-runner-controller/build.CommitSHA=${COMMIT_SHA}'" -o /out/manager main.go && \ go build -trimpath -ldflags="-s -w -X 'github.com/actions/actions-runner-controller/build.Version=${VERSION}' -X 'github.com/actions/actions-runner-controller/build.CommitSHA=${COMMIT_SHA}'" -o /out/github-runnerscaleset-listener ./cmd/githubrunnerscalesetlistener && \ + go build -trimpath -ldflags="-s -w -X 'github.com/actions/actions-runner-controller/build.Version=${VERSION}' -X 'github.com/actions/actions-runner-controller/build.CommitSHA=${COMMIT_SHA}'" -o /out/ghalistener ./cmd/ghalistener && \ go build -trimpath -ldflags="-s -w" -o /out/github-webhook-server ./cmd/githubwebhookserver && \ go build -trimpath -ldflags="-s -w" -o /out/actions-metrics-server ./cmd/actionsmetricsserver && \ go build -trimpath -ldflags="-s -w" -o /out/sleep ./cmd/sleep @@ -52,6 +53,7 @@ COPY --from=builder /out/manager . COPY --from=builder /out/github-webhook-server . COPY --from=builder /out/actions-metrics-server . COPY --from=builder /out/github-runnerscaleset-listener . +COPY --from=builder /out/ghalistener . COPY --from=builder /out/sleep . 
USER 65532:65532 diff --git a/cmd/ghalistener/app/app.go b/cmd/ghalistener/app/app.go new file mode 100644 index 00000000..e8f64f21 --- /dev/null +++ b/cmd/ghalistener/app/app.go @@ -0,0 +1,133 @@ +package app + +import ( + "context" + "errors" + "fmt" + + "github.com/actions/actions-runner-controller/cmd/ghalistener/config" + "github.com/actions/actions-runner-controller/cmd/ghalistener/listener" + "github.com/actions/actions-runner-controller/cmd/ghalistener/metrics" + "github.com/actions/actions-runner-controller/cmd/ghalistener/worker" + "github.com/actions/actions-runner-controller/github/actions" + "github.com/go-logr/logr" + "golang.org/x/sync/errgroup" +) + +// App is responsible for initializing required components and running the app. +type App struct { + // configured fields + config config.Config + logger logr.Logger + + // initialized fields + listener Listener + worker Worker + metrics metrics.ServerPublisher +} + +//go:generate mockery --name Listener --output ./mocks --outpkg mocks --case underscore +type Listener interface { + Listen(ctx context.Context, handler listener.Handler) error +} + +//go:generate mockery --name Worker --output ./mocks --outpkg mocks --case underscore +type Worker interface { + HandleJobStarted(ctx context.Context, jobInfo *actions.JobStarted) error + HandleDesiredRunnerCount(ctx context.Context, desiredRunnerCount int) error +} + +func New(config config.Config) (*App, error) { + app := &App{ + config: config, + } + + ghConfig, err := actions.ParseGitHubConfigFromURL(config.ConfigureUrl) + if err != nil { + return nil, fmt.Errorf("failed to parse GitHub config from URL: %w", err) + } + + { + logger, err := config.Logger() + if err != nil { + return nil, fmt.Errorf("failed to create logger: %w", err) + } + app.logger = logger.WithName("listener-app") + } + + actionsClient, err := config.ActionsClient(app.logger) + if err != nil { + return nil, fmt.Errorf("failed to create actions client: %w", err) + } + + if 
config.MetricsAddr != "" { + app.metrics = metrics.NewExporter(metrics.ExporterConfig{ + ScaleSetName: config.EphemeralRunnerSetName, + ScaleSetNamespace: config.EphemeralRunnerSetNamespace, + Enterprise: ghConfig.Enterprise, + Organization: ghConfig.Organization, + Repository: ghConfig.Repository, + ServerAddr: config.MetricsAddr, + ServerEndpoint: config.MetricsEndpoint, + }) + } + + worker, err := worker.New( + worker.Config{ + EphemeralRunnerSetNamespace: config.EphemeralRunnerSetNamespace, + EphemeralRunnerSetName: config.EphemeralRunnerSetName, + MaxRunners: config.MaxRunners, + MinRunners: config.MinRunners, + }, + worker.WithLogger(app.logger.WithName("worker")), + ) + if err != nil { + return nil, fmt.Errorf("failed to create new kubernetes worker: %w", err) + } + app.worker = worker + + listener, err := listener.New(listener.Config{ + Client: actionsClient, + ScaleSetID: app.config.RunnerScaleSetId, + MinRunners: app.config.MinRunners, + MaxRunners: app.config.MaxRunners, + Logger: app.logger.WithName("listener"), + Metrics: app.metrics, + }) + if err != nil { + return nil, fmt.Errorf("failed to create new listener: %w", err) + } + app.listener = listener + + app.logger.Info("app initialized") + + return app, nil +} + +func (app *App) Run(ctx context.Context) error { + var errs []error + if app.worker == nil { + errs = append(errs, fmt.Errorf("worker not initialized")) + } + if app.listener == nil { + errs = append(errs, fmt.Errorf("listener not initialized")) + } + if err := errors.Join(errs...); err != nil { + return fmt.Errorf("app not initialized: %w", err) + } + + g, ctx := errgroup.WithContext(ctx) + g.Go(func() error { + app.logger.Info("Starting listener") + return app.listener.Listen(ctx, app.worker) + }) + + if app.metrics != nil { + g.Go(func() error { + app.logger.Info("Starting metrics server") + return app.metrics.ListenAndServe(ctx) + }) + } + + return g.Wait() +} diff --git a/cmd/ghalistener/app/app_test.go 
b/cmd/ghalistener/app/app_test.go new file mode 100644 index 00000000..883add35 --- /dev/null +++ b/cmd/ghalistener/app/app_test.go @@ -0,0 +1,85 @@ +package app + +import ( + "context" + "errors" + "testing" + + appmocks "github.com/actions/actions-runner-controller/cmd/ghalistener/app/mocks" + "github.com/actions/actions-runner-controller/cmd/ghalistener/listener" + metricsMocks "github.com/actions/actions-runner-controller/cmd/ghalistener/metrics/mocks" + "github.com/actions/actions-runner-controller/cmd/ghalistener/worker" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestApp_Run(t *testing.T) { + t.Parallel() + + t.Run("ListenerWorkerGuard", func(t *testing.T) { + invalidApps := []*App{ + {}, + {worker: &worker.Worker{}}, + {listener: &listener.Listener{}}, + } + + for _, app := range invalidApps { + assert.Error(t, app.Run(context.Background())) + } + }) + + t.Run("ExitsOnListenerError", func(t *testing.T) { + listener := appmocks.NewListener(t) + worker := appmocks.NewWorker(t) + + listener.On("Listen", mock.Anything, mock.Anything).Return(errors.New("listener error")).Once() + + app := &App{ + listener: listener, + worker: worker, + } + + err := app.Run(context.Background()) + assert.Error(t, err) + }) + + t.Run("ExitsOnListenerNil", func(t *testing.T) { + listener := appmocks.NewListener(t) + worker := appmocks.NewWorker(t) + + listener.On("Listen", mock.Anything, mock.Anything).Return(nil).Once() + + app := &App{ + listener: listener, + worker: worker, + } + + err := app.Run(context.Background()) + assert.NoError(t, err) + }) + + t.Run("CancelListenerOnMetricsServerError", func(t *testing.T) { + listener := appmocks.NewListener(t) + worker := appmocks.NewWorker(t) + metrics := metricsMocks.NewServerPublisher(t) + ctx := context.Background() + + listener.On("Listen", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + ctx := args.Get(0).(context.Context) + go func() { + <-ctx.Done() + }() + 
}).Return(nil).Once() + + metrics.On("ListenAndServe", mock.Anything).Return(errors.New("metrics server error")).Once() + + app := &App{ + listener: listener, + worker: worker, + metrics: metrics, + } + + err := app.Run(ctx) + assert.Error(t, err) + }) +} diff --git a/cmd/ghalistener/app/mocks/listener.go b/cmd/ghalistener/app/mocks/listener.go new file mode 100644 index 00000000..c177ace6 --- /dev/null +++ b/cmd/ghalistener/app/mocks/listener.go @@ -0,0 +1,43 @@ +// Code generated by mockery v2.36.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + listener "github.com/actions/actions-runner-controller/cmd/ghalistener/listener" + mock "github.com/stretchr/testify/mock" +) + +// Listener is an autogenerated mock type for the Listener type +type Listener struct { + mock.Mock +} + +// Listen provides a mock function with given fields: ctx, handler +func (_m *Listener) Listen(ctx context.Context, handler listener.Handler) error { + ret := _m.Called(ctx, handler) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, listener.Handler) error); ok { + r0 = rf(ctx, handler) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewListener creates a new instance of Listener. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewListener(t interface { + mock.TestingT + Cleanup(func()) +}) *Listener { + mock := &Listener{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/cmd/ghalistener/app/mocks/worker.go b/cmd/ghalistener/app/mocks/worker.go new file mode 100644 index 00000000..a2561adb --- /dev/null +++ b/cmd/ghalistener/app/mocks/worker.go @@ -0,0 +1,58 @@ +// Code generated by mockery v2.36.1. DO NOT EDIT. 
+ +package mocks + +import ( + actions "github.com/actions/actions-runner-controller/github/actions" + + context "context" + + mock "github.com/stretchr/testify/mock" +) + +// Worker is an autogenerated mock type for the Worker type +type Worker struct { + mock.Mock +} + +// HandleDesiredRunnerCount provides a mock function with given fields: ctx, desiredRunnerCount +func (_m *Worker) HandleDesiredRunnerCount(ctx context.Context, desiredRunnerCount int) error { + ret := _m.Called(ctx, desiredRunnerCount) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, int) error); ok { + r0 = rf(ctx, desiredRunnerCount) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// HandleJobStarted provides a mock function with given fields: ctx, jobInfo +func (_m *Worker) HandleJobStarted(ctx context.Context, jobInfo *actions.JobStarted) error { + ret := _m.Called(ctx, jobInfo) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *actions.JobStarted) error); ok { + r0 = rf(ctx, jobInfo) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewWorker creates a new instance of Worker. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. 
+func NewWorker(t interface { + mock.TestingT + Cleanup(func()) +}) *Worker { + mock := &Worker{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/cmd/ghalistener/config/config.go b/cmd/ghalistener/config/config.go new file mode 100644 index 00000000..cc229838 --- /dev/null +++ b/cmd/ghalistener/config/config.go @@ -0,0 +1,147 @@ +package config + +import ( + "crypto/x509" + "encoding/json" + "fmt" + "os" + + "github.com/actions/actions-runner-controller/build" + "github.com/actions/actions-runner-controller/github/actions" + "github.com/actions/actions-runner-controller/logging" + "github.com/go-logr/logr" +) + +type Config struct { + ConfigureUrl string `json:"configureUrl"` + AppID int64 `json:"appID"` + AppInstallationID int64 `json:"appInstallationID"` + AppPrivateKey string `json:"appPrivateKey"` + Token string `json:"token"` + EphemeralRunnerSetNamespace string `json:"ephemeralRunnerSetNamespace"` + EphemeralRunnerSetName string `json:"ephemeralRunnerSetName"` + MaxRunners int `json:"maxRunners"` + MinRunners int `json:"minRunners"` + RunnerScaleSetId int `json:"runnerScaleSetId"` + RunnerScaleSetName string `json:"runnerScaleSetName"` + ServerRootCA string `json:"serverRootCA"` + LogLevel string `json:"logLevel"` + LogFormat string `json:"logFormat"` + MetricsAddr string `json:"metricsAddr"` + MetricsEndpoint string `json:"metricsEndpoint"` +} + +func Read(path string) (Config, error) { + f, err := os.Open(path) + if err != nil { + return Config{}, err + } + defer f.Close() + + var config Config + if err := json.NewDecoder(f).Decode(&config); err != nil { + return Config{}, fmt.Errorf("failed to decode config: %w", err) + } + + if err := config.validate(); err != nil { + return Config{}, fmt.Errorf("failed to validate config: %w", err) + } + + return config, nil +} + +func (c *Config) validate() error { + if len(c.ConfigureUrl) == 0 { + return fmt.Errorf("GitHubConfigUrl is not provided") + } + + if 
len(c.EphemeralRunnerSetNamespace) == 0 || len(c.EphemeralRunnerSetName) == 0 { + return fmt.Errorf("EphemeralRunnerSetNamespace '%s' or EphemeralRunnerSetName '%s' is missing", c.EphemeralRunnerSetNamespace, c.EphemeralRunnerSetName) + } + + if c.RunnerScaleSetId == 0 { + return fmt.Errorf("RunnerScaleSetId '%d' is missing", c.RunnerScaleSetId) + } + + if c.MaxRunners < c.MinRunners { + return fmt.Errorf("MinRunners '%d' cannot be greater than MaxRunners '%d'", c.MinRunners, c.MaxRunners) + } + + hasToken := len(c.Token) > 0 + hasPrivateKeyConfig := c.AppID > 0 && c.AppPrivateKey != "" + + if !hasToken && !hasPrivateKeyConfig { + return fmt.Errorf("GitHub auth credential is missing, token length: '%d', appId: '%d', installationId: '%d', private key length: '%d'", len(c.Token), c.AppID, c.AppInstallationID, len(c.AppPrivateKey)) + } + + if hasToken && hasPrivateKeyConfig { + return fmt.Errorf("only one GitHub auth method supported at a time. Have both PAT and App auth: token length: '%d', appId: '%d', installationId: '%d', private key length: '%d'", len(c.Token), c.AppID, c.AppInstallationID, len(c.AppPrivateKey)) + } + + return nil +} + +func (c *Config) Logger() (logr.Logger, error) { + logLevel := string(logging.LogLevelDebug) + if c.LogLevel != "" { + logLevel = c.LogLevel + } + + logFormat := string(logging.LogFormatText) + if c.LogFormat != "" { + logFormat = c.LogFormat + } + + logger, err := logging.NewLogger(logLevel, logFormat) + if err != nil { + return logr.Logger{}, fmt.Errorf("NewLogger failed: %w", err) + } + + return logger, nil +} + +func (c *Config) ActionsClient(logger logr.Logger) (*actions.Client, error) { + + var creds actions.ActionsAuth + switch c.Token { + case "": + creds.AppCreds = &actions.GitHubAppAuth{ + AppID: c.AppID, + AppInstallationID: c.AppInstallationID, + AppPrivateKey: c.AppPrivateKey, + } + default: + creds.Token = c.Token + } + + options := []actions.ClientOption{ + actions.WithLogger(logger), + } + + if c.ServerRootCA != "" { 
+ systemPool, err := x509.SystemCertPool() + if err != nil { + return nil, fmt.Errorf("failed to load system cert pool: %w", err) + } + pool := systemPool.Clone() + ok := pool.AppendCertsFromPEM([]byte(c.ServerRootCA)) + if !ok { + return nil, fmt.Errorf("failed to parse root certificate") + } + + options = append(options, actions.WithRootCAs(pool)) + } + + client, err := actions.NewClient(c.ConfigureUrl, &creds, options...) + if err != nil { + return nil, fmt.Errorf("failed to create actions client: %w", err) + } + + client.SetUserAgent(actions.UserAgentInfo{ + Version: build.Version, + CommitSHA: build.CommitSHA, + ScaleSetID: c.RunnerScaleSetId, + }) + + return client, nil +} diff --git a/cmd/ghalistener/config/config_test.go b/cmd/ghalistener/config/config_test.go new file mode 100644 index 00000000..99e6ac99 --- /dev/null +++ b/cmd/ghalistener/config/config_test.go @@ -0,0 +1,92 @@ +package config + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConfigValidationMinMax(t *testing.T) { + config := &Config{ + ConfigureUrl: "github.com/some_org/some_repo", + EphemeralRunnerSetNamespace: "namespace", + EphemeralRunnerSetName: "deployment", + RunnerScaleSetId: 1, + MinRunners: 5, + MaxRunners: 2, + Token: "token", + } + err := config.validate() + assert.ErrorContains(t, err, "MinRunners '5' cannot be greater than MaxRunners '2", "Expected error about MinRunners > MaxRunners") +} + +func TestConfigValidationMissingToken(t *testing.T) { + config := &Config{ + ConfigureUrl: "github.com/some_org/some_repo", + EphemeralRunnerSetNamespace: "namespace", + EphemeralRunnerSetName: "deployment", + RunnerScaleSetId: 1, + } + err := config.validate() + expectedError := fmt.Sprintf("GitHub auth credential is missing, token length: '%d', appId: '%d', installationId: '%d', private key length: '%d", len(config.Token), config.AppID, config.AppInstallationID, len(config.AppPrivateKey)) + assert.ErrorContains(t, err, expectedError, "Expected 
error about missing auth") +} + +func TestConfigValidationAppKey(t *testing.T) { + config := &Config{ + AppID: 1, + AppInstallationID: 10, + ConfigureUrl: "github.com/some_org/some_repo", + EphemeralRunnerSetNamespace: "namespace", + EphemeralRunnerSetName: "deployment", + RunnerScaleSetId: 1, + } + err := config.validate() + expectedError := fmt.Sprintf("GitHub auth credential is missing, token length: '%d', appId: '%d', installationId: '%d', private key length: '%d", len(config.Token), config.AppID, config.AppInstallationID, len(config.AppPrivateKey)) + assert.ErrorContains(t, err, expectedError, "Expected error about missing auth") +} + +func TestConfigValidationOnlyOneTypeOfCredentials(t *testing.T) { + config := &Config{ + AppID: 1, + AppInstallationID: 10, + AppPrivateKey: "asdf", + Token: "asdf", + ConfigureUrl: "github.com/some_org/some_repo", + EphemeralRunnerSetNamespace: "namespace", + EphemeralRunnerSetName: "deployment", + RunnerScaleSetId: 1, + } + err := config.validate() + expectedError := fmt.Sprintf("only one GitHub auth method supported at a time. 
Have both PAT and App auth: token length: '%d', appId: '%d', installationId: '%d', private key length: '%d", len(config.Token), config.AppID, config.AppInstallationID, len(config.AppPrivateKey)) + assert.ErrorContains(t, err, expectedError, "Expected error about missing auth") +} + +func TestConfigValidation(t *testing.T) { + config := &Config{ + ConfigureUrl: "https://github.com/actions", + EphemeralRunnerSetNamespace: "namespace", + EphemeralRunnerSetName: "deployment", + RunnerScaleSetId: 1, + MinRunners: 1, + MaxRunners: 5, + Token: "asdf", + } + + err := config.validate() + + assert.NoError(t, err, "Expected no error") +} + +func TestConfigValidationConfigUrl(t *testing.T) { + config := &Config{ + EphemeralRunnerSetNamespace: "namespace", + EphemeralRunnerSetName: "deployment", + RunnerScaleSetId: 1, + } + + err := config.validate() + + assert.ErrorContains(t, err, "GitHubConfigUrl is not provided", "Expected error about missing ConfigureUrl") +} diff --git a/cmd/ghalistener/listener/listener.go b/cmd/ghalistener/listener/listener.go new file mode 100644 index 00000000..e90622a0 --- /dev/null +++ b/cmd/ghalistener/listener/listener.go @@ -0,0 +1,388 @@ +package listener + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "os" + "time" + + "github.com/actions/actions-runner-controller/cmd/ghalistener/metrics" + "github.com/actions/actions-runner-controller/github/actions" + "github.com/go-logr/logr" + "github.com/google/uuid" +) + +const ( + sessionCreationMaxRetries = 10 +) + +// message types +const ( + messageTypeJobAvailable = "JobAvailable" + messageTypeJobAssigned = "JobAssigned" + messageTypeJobStarted = "JobStarted" + messageTypeJobCompleted = "JobCompleted" +) + +//go:generate mockery --name Client --output ./mocks --outpkg mocks --case underscore +type Client interface { + GetAcquirableJobs(ctx context.Context, runnerScaleSetId int) (*actions.AcquirableJobList, error) + CreateMessageSession(ctx context.Context, runnerScaleSetId 
int, owner string) (*actions.RunnerScaleSetSession, error) + GetMessage(ctx context.Context, messageQueueUrl, messageQueueAccessToken string, lastMessageId int64) (*actions.RunnerScaleSetMessage, error) + DeleteMessage(ctx context.Context, messageQueueUrl, messageQueueAccessToken string, messageId int64) error + AcquireJobs(ctx context.Context, runnerScaleSetId int, messageQueueAccessToken string, requestIds []int64) ([]int64, error) + RefreshMessageSession(ctx context.Context, runnerScaleSetId int, sessionId *uuid.UUID) (*actions.RunnerScaleSetSession, error) +} + +type Config struct { + Client Client + ScaleSetID int + MinRunners int + MaxRunners int + Logger logr.Logger + Metrics metrics.Publisher +} + +func (c *Config) Validate() error { + if c.Client == nil { + return errors.New("client is required") + } + if c.ScaleSetID == 0 { + return errors.New("scaleSetID is required") + } + if c.MinRunners < 0 { + return errors.New("minRunners must be greater than or equal to 0") + } + if c.MaxRunners < 0 { + return errors.New("maxRunners must be greater than or equal to 0") + } + if c.MaxRunners > 0 && c.MinRunners > c.MaxRunners { + return errors.New("minRunners must be less than or equal to maxRunners") + } + return nil +} + +// The Listener's role is to manage all interactions with the actions service. +// It receives messages and processes them using the given handler. +type Listener struct { + // configured fields + scaleSetID int // The ID of the scale set associated with the listener. + client Client // The client used to interact with the scale set. + metrics metrics.Publisher // The publisher used to publish metrics. + + // internal fields + logger logr.Logger // The logger used for logging. + hostname string // The hostname of the listener. + + // updated fields + lastMessageID int64 // The ID of the last processed message. + session *actions.RunnerScaleSetSession // The session for managing the runner scale set. 
+} + +func New(config Config) (*Listener, error) { + if err := config.Validate(); err != nil { + return nil, fmt.Errorf("invalid config: %w", err) + } + + listener := &Listener{ + scaleSetID: config.ScaleSetID, + client: config.Client, + logger: config.Logger, + metrics: metrics.Discard, + } + + if config.Metrics != nil { + listener.metrics = config.Metrics + } + + listener.metrics.PublishStatic(config.MinRunners, config.MaxRunners) + + hostname, err := os.Hostname() + if err != nil { + hostname = uuid.NewString() + listener.logger.Info("Failed to get hostname, fallback to uuid", "uuid", hostname, "error", err) + } + listener.hostname = hostname + + return listener, nil +} + +//go:generate mockery --name Handler --output ./mocks --outpkg mocks --case underscore +type Handler interface { + HandleJobStarted(ctx context.Context, jobInfo *actions.JobStarted) error + HandleDesiredRunnerCount(ctx context.Context, desiredRunnerCount int) error +} + +// Listen listens for incoming messages and handles them using the provided handler. +// It continuously listens for messages until the context is cancelled. +// The initial message contains the current statistics and acquirable jobs, if any. +// The handler is responsible for handling the initial message and subsequent messages. +// If an error occurs during any step, Listen returns an error. 
+func (l *Listener) Listen(ctx context.Context, handler Handler) error { + if err := l.createSession(ctx); err != nil { + return fmt.Errorf("createSession failed: %w", err) + } + + initialMessage := &actions.RunnerScaleSetMessage{ + MessageId: 0, + MessageType: "RunnerScaleSetJobMessages", + Statistics: l.session.Statistics, + Body: "", + } + + if l.session.Statistics.TotalAvailableJobs > 0 || l.session.Statistics.TotalAssignedJobs > 0 { + acquirableJobs, err := l.client.GetAcquirableJobs(ctx, l.scaleSetID) + if err != nil { + return fmt.Errorf("failed to call GetAcquirableJobs: %w", err) + } + + acquirableJobsJson, err := json.Marshal(acquirableJobs) + if err != nil { + return fmt.Errorf("failed to marshal acquirable jobs: %w", err) + } + + initialMessage.Body = string(acquirableJobsJson) + } + + if err := handler.HandleDesiredRunnerCount(ctx, initialMessage.Statistics.TotalAssignedJobs); err != nil { + return fmt.Errorf("handling initial message failed: %w", err) + } + + for { + select { + case <-ctx.Done(): + return fmt.Errorf("context cancelled: %w", ctx.Err()) + default: + } + + msg, err := l.getMessage(ctx) + if err != nil { + return fmt.Errorf("failed to get message: %w", err) + } + + if msg == nil { + continue + } + + statistics, jobsStarted, err := l.parseMessage(ctx, msg) + if err != nil { + return fmt.Errorf("failed to parse message: %w", err) + } + + l.lastMessageID = msg.MessageId + + if err := l.deleteLastMessage(ctx); err != nil { + return fmt.Errorf("failed to delete message: %w", err) + } + + for _, jobStarted := range jobsStarted { + if err := handler.HandleJobStarted(ctx, jobStarted); err != nil { + return fmt.Errorf("failed to handle job started: %w", err) + } + } + + if err := handler.HandleDesiredRunnerCount(ctx, statistics.TotalAssignedJobs); err != nil { + return fmt.Errorf("failed to handle desired runner count: %w", err) + } + } +} + +func (l *Listener) createSession(ctx context.Context) error { + var session *actions.RunnerScaleSetSession 
+ var retries int + + for { + var err error + session, err = l.client.CreateMessageSession(ctx, l.scaleSetID, l.hostname) + if err == nil { + break + } + + clientErr := &actions.HttpClientSideError{} + if !errors.As(err, &clientErr) { + return fmt.Errorf("failed to create session: %w", err) + } + + if clientErr.Code != http.StatusConflict { + return fmt.Errorf("failed to create session: %w", err) + } + + retries++ + if retries >= sessionCreationMaxRetries { + return fmt.Errorf("failed to create session after %d retries: %w", retries, err) + } + + l.logger.Info("Unable to create message session. Will try again in 30 seconds", "error", err.Error()) + + select { + case <-ctx.Done(): + return fmt.Errorf("context cancelled: %w", ctx.Err()) + case <-time.After(30 * time.Second): + } + } + + statistics, err := json.Marshal(session.Statistics) + if err != nil { + return fmt.Errorf("failed to marshal statistics: %w", err) + } + l.logger.Info("Current runner scale set statistics.", "statistics", string(statistics)) + + l.session = session + + return nil +} + +func (l *Listener) getMessage(ctx context.Context) (*actions.RunnerScaleSetMessage, error) { + l.logger.Info("Getting next message", "lastMessageID", l.lastMessageID) + msg, err := l.client.GetMessage(ctx, l.session.MessageQueueUrl, l.session.MessageQueueAccessToken, l.lastMessageID) + if err == nil { // if NO error + return msg, nil + } + + expiredError := &actions.MessageQueueTokenExpiredError{} + if !errors.As(err, &expiredError) { + return nil, fmt.Errorf("failed to get next message: %w", err) + } + + if err := l.refreshSession(ctx); err != nil { + return nil, err + } + + l.logger.Info("Getting next message", "lastMessageID", l.lastMessageID) + + msg, err = l.client.GetMessage(ctx, l.session.MessageQueueUrl, l.session.MessageQueueAccessToken, l.lastMessageID) + if err != nil { // error persists even after session refresh + return nil, fmt.Errorf("failed to get next message after message session refresh: %w", err) + } + + return msg, nil + +} + 
+func (l *Listener) deleteLastMessage(ctx context.Context) error { + l.logger.Info("Deleting last message", "lastMessageID", l.lastMessageID) + if err := l.client.DeleteMessage(ctx, l.session.MessageQueueUrl, l.session.MessageQueueAccessToken, l.lastMessageID); err != nil { + return fmt.Errorf("failed to delete message: %w", err) + } + + return nil +} + +func (l *Listener) parseMessage(ctx context.Context, msg *actions.RunnerScaleSetMessage) (*actions.RunnerScaleSetStatistic, []*actions.JobStarted, error) { + l.logger.Info("Processing message", "messageId", msg.MessageId, "messageType", msg.MessageType) + if msg.Statistics == nil { + return nil, nil, fmt.Errorf("invalid message: statistics is nil") + } + + l.logger.Info("New runner scale set statistics.", "statistics", msg.Statistics) + + if msg.MessageType != "RunnerScaleSetJobMessages" { + l.logger.Info("Skipping message", "messageType", msg.MessageType) + return nil, nil, fmt.Errorf("invalid message type: %s", msg.MessageType) + } + + var batchedMessages []json.RawMessage + if len(msg.Body) > 0 { + if err := json.Unmarshal([]byte(msg.Body), &batchedMessages); err != nil { + return nil, nil, fmt.Errorf("failed to unmarshal batched messages: %w", err) + } + } + + var availableJobs []int64 + var startedJobs []*actions.JobStarted + for _, msg := range batchedMessages { + var messageType actions.JobMessageType + if err := json.Unmarshal(msg, &messageType); err != nil { + return nil, nil, fmt.Errorf("failed to decode job message type: %w", err) + } + + switch messageType.MessageType { + case messageTypeJobAvailable: + var jobAvailable actions.JobAvailable + if err := json.Unmarshal(msg, &jobAvailable); err != nil { + return nil, nil, fmt.Errorf("failed to decode job available: %w", err) + } + + l.logger.Info("Job available message received", "jobId", jobAvailable.RunnerRequestId) + availableJobs = append(availableJobs, jobAvailable.RunnerRequestId) + + case messageTypeJobAssigned: + var jobAssigned actions.JobAssigned 
+		if err := json.Unmarshal(msg, &jobAssigned); err != nil {
+			return nil, nil, fmt.Errorf("failed to decode job assigned: %w", err)
+		}
+
+		l.logger.Info("Job assigned message received", "jobId", jobAssigned.RunnerRequestId)
+
+	case messageTypeJobStarted:
+		var jobStarted actions.JobStarted
+		if err := json.Unmarshal(msg, &jobStarted); err != nil {
+			return nil, nil, fmt.Errorf("failed to decode job started: %w", err)
+		}
+		l.logger.Info("Job started message received.", "RequestId", jobStarted.RunnerRequestId, "RunnerId", jobStarted.RunnerId)
+		startedJobs = append(startedJobs, &jobStarted)
+
+	case messageTypeJobCompleted:
+		var jobCompleted actions.JobCompleted
+		if err := json.Unmarshal(msg, &jobCompleted); err != nil {
+			return nil, nil, fmt.Errorf("failed to decode job completed: %w", err)
+		}
+
+		l.logger.Info("Job completed message received.", "RequestId", jobCompleted.RunnerRequestId, "Result", jobCompleted.Result, "RunnerId", jobCompleted.RunnerId, "RunnerName", jobCompleted.RunnerName)
+
+	default:
+		l.logger.Info("unknown job message type.", "messageType", messageType.MessageType)
+	}
+	}
+
+	l.logger.Info("Available jobs.", "count", len(availableJobs), "requestIds", fmt.Sprint(availableJobs))
+	if len(availableJobs) > 0 {
+		acquired, err := l.acquireAvailableJobs(ctx, availableJobs)
+		if err != nil {
+			return nil, nil, err
+		}
+
+		l.logger.Info("Jobs are acquired", "count", len(acquired), "requestIds", fmt.Sprint(acquired))
+	}
+
+	return msg.Statistics, startedJobs, nil
+}
+
+// acquireAvailableJobs claims the given job request IDs for this scale set,
+// retrying exactly once after a session refresh if the queue token expired.
+func (l *Listener) acquireAvailableJobs(ctx context.Context, availableJobs []int64) ([]int64, error) {
+	l.logger.Info("Acquiring jobs")
+
+	ids, err := l.client.AcquireJobs(ctx, l.scaleSetID, l.session.MessageQueueAccessToken, availableJobs)
+	if err == nil { // if NO errors
+		return ids, nil
+	}
+
+	expiredError := &actions.MessageQueueTokenExpiredError{}
+	if !errors.As(err, &expiredError) {
+		return nil, fmt.Errorf("failed to acquire jobs: %w", err)
+	}
+
+	if err := l.refreshSession(ctx); err != nil {
+		return nil, err
+	}
+
+	ids, err = l.client.AcquireJobs(ctx, l.scaleSetID, l.session.MessageQueueAccessToken, availableJobs)
+	if err != nil {
+		return nil, fmt.Errorf("failed to acquire jobs after session refresh: %w", err)
+	}
+
+	return ids, nil
+}
+
+// refreshSession renews the message session (called from both message polling
+// and job acquisition when the queue access token has expired).
+func (l *Listener) refreshSession(ctx context.Context) error {
+	l.logger.Info("Message queue token is expired, refreshing session...")
+	session, err := l.client.RefreshMessageSession(ctx, l.session.RunnerScaleSet.Id, l.session.SessionId)
+	if err != nil {
+		return fmt.Errorf("failed to refresh message session: %w", err)
+	}
+
+	l.session = session
+	return nil
+}
diff --git a/cmd/ghalistener/listener/listener_test.go b/cmd/ghalistener/listener/listener_test.go
new file mode 100644
index 00000000..86b69b83
--- /dev/null
+++ b/cmd/ghalistener/listener/listener_test.go
@@ -0,0 +1,613 @@
+package listener
+
+import (
+	"context"
+	"errors"
+	"net/http"
+	"testing"
+	"time"
+
+	listenermocks "github.com/actions/actions-runner-controller/cmd/ghalistener/listener/mocks"
+	"github.com/actions/actions-runner-controller/cmd/ghalistener/metrics"
+	metricsmocks "github.com/actions/actions-runner-controller/cmd/ghalistener/metrics/mocks"
+	"github.com/actions/actions-runner-controller/github/actions"
+	"github.com/google/uuid"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/mock"
+	"github.com/stretchr/testify/require"
+)
+
+func TestNew(t *testing.T) {
+	t.Parallel()
+	t.Run("InvalidConfig", func(t *testing.T) {
+		t.Parallel()
+		var config Config
+		_, err := New(config)
+		assert.NotNil(t, err)
+	})
+
+	t.Run("ValidConfig", func(t *testing.T) {
+		t.Parallel()
+		config := Config{
+			Client:     listenermocks.NewClient(t),
+			ScaleSetID: 1,
+			Metrics:    metrics.Discard,
+		}
+		l, err := New(config)
+		assert.Nil(t, err)
+		assert.NotNil(t, l)
+	})
+
+	t.Run("SetStaticMetrics", func(t *testing.T) {
+		t.Parallel()
+
+		metrics := metricsmocks.NewPublisher(t)
+
+		
metrics.On("PublishStatic", mock.Anything, mock.Anything).Once() + + config := Config{ + Client: listenermocks.NewClient(t), + ScaleSetID: 1, + Metrics: metrics, + } + l, err := New(config) + assert.Nil(t, err) + assert.NotNil(t, l) + }) +} + +func TestListener_createSession(t *testing.T) { + t.Parallel() + t.Run("FailOnce", func(t *testing.T) { + t.Parallel() + ctx := context.Background() + + config := Config{ + ScaleSetID: 1, + Metrics: metrics.Discard, + } + + client := listenermocks.NewClient(t) + client.On("CreateMessageSession", ctx, mock.Anything, mock.Anything).Return(nil, assert.AnError).Once() + config.Client = client + + l, err := New(config) + require.Nil(t, err) + + err = l.createSession(ctx) + assert.NotNil(t, err) + }) + + t.Run("FailContext", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + config := Config{ + ScaleSetID: 1, + Metrics: metrics.Discard, + } + + client := listenermocks.NewClient(t) + client.On("CreateMessageSession", ctx, mock.Anything, mock.Anything).Return(nil, + &actions.HttpClientSideError{Code: http.StatusConflict}).Once() + config.Client = client + + l, err := New(config) + require.Nil(t, err) + + err = l.createSession(ctx) + assert.True(t, errors.Is(err, context.DeadlineExceeded)) + }) + + t.Run("SetsSession", func(t *testing.T) { + t.Parallel() + config := Config{ + ScaleSetID: 1, + Metrics: metrics.Discard, + } + + client := listenermocks.NewClient(t) + + uuid := uuid.New() + session := &actions.RunnerScaleSetSession{ + SessionId: &uuid, + OwnerName: "example", + RunnerScaleSet: &actions.RunnerScaleSet{}, + MessageQueueUrl: "https://example.com", + MessageQueueAccessToken: "1234567890", + Statistics: nil, + } + client.On("CreateMessageSession", mock.Anything, mock.Anything, mock.Anything).Return(session, nil).Once() + config.Client = client + + l, err := New(config) + require.Nil(t, err) + + err = l.createSession(context.Background()) + 
assert.Nil(t, err) + assert.Equal(t, session, l.session) + }) +} + +func TestListener_getMessage(t *testing.T) { + t.Parallel() + + t.Run("ReceivesMessage", func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + config := Config{ + ScaleSetID: 1, + Metrics: metrics.Discard, + } + + client := listenermocks.NewClient(t) + want := &actions.RunnerScaleSetMessage{ + MessageId: 1, + } + client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything).Return(want, nil).Once() + config.Client = client + + l, err := New(config) + require.Nil(t, err) + l.session = &actions.RunnerScaleSetSession{} + + got, err := l.getMessage(ctx) + assert.Nil(t, err) + assert.Equal(t, want, got) + }) + + t.Run("NotExpiredError", func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + config := Config{ + ScaleSetID: 1, + Metrics: metrics.Discard, + } + + client := listenermocks.NewClient(t) + client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything).Return(nil, &actions.HttpClientSideError{Code: http.StatusNotFound}).Once() + config.Client = client + + l, err := New(config) + require.Nil(t, err) + + l.session = &actions.RunnerScaleSetSession{} + + _, err = l.getMessage(ctx) + assert.NotNil(t, err) + }) + + t.Run("RefreshAndSucceeds", func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + config := Config{ + ScaleSetID: 1, + Metrics: metrics.Discard, + } + + client := listenermocks.NewClient(t) + + uuid := uuid.New() + session := &actions.RunnerScaleSetSession{ + SessionId: &uuid, + OwnerName: "example", + RunnerScaleSet: &actions.RunnerScaleSet{}, + MessageQueueUrl: "https://example.com", + MessageQueueAccessToken: "1234567890", + Statistics: nil, + } + client.On("RefreshMessageSession", ctx, mock.Anything, mock.Anything).Return(session, nil).Once() + + client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything).Return(nil, &actions.MessageQueueTokenExpiredError{}).Once() + + want := 
&actions.RunnerScaleSetMessage{ + MessageId: 1, + } + client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything).Return(want, nil).Once() + + config.Client = client + + l, err := New(config) + require.Nil(t, err) + + l.session = &actions.RunnerScaleSetSession{ + SessionId: &uuid, + RunnerScaleSet: &actions.RunnerScaleSet{}, + } + + got, err := l.getMessage(ctx) + assert.Nil(t, err) + assert.Equal(t, want, got) + }) + + t.Run("RefreshAndFails", func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + config := Config{ + ScaleSetID: 1, + Metrics: metrics.Discard, + } + + client := listenermocks.NewClient(t) + + uuid := uuid.New() + session := &actions.RunnerScaleSetSession{ + SessionId: &uuid, + OwnerName: "example", + RunnerScaleSet: &actions.RunnerScaleSet{}, + MessageQueueUrl: "https://example.com", + MessageQueueAccessToken: "1234567890", + Statistics: nil, + } + client.On("RefreshMessageSession", ctx, mock.Anything, mock.Anything).Return(session, nil).Once() + + client.On("GetMessage", ctx, mock.Anything, mock.Anything, mock.Anything).Return(nil, &actions.MessageQueueTokenExpiredError{}).Twice() + + config.Client = client + + l, err := New(config) + require.Nil(t, err) + + l.session = &actions.RunnerScaleSetSession{ + SessionId: &uuid, + RunnerScaleSet: &actions.RunnerScaleSet{}, + } + + got, err := l.getMessage(ctx) + assert.NotNil(t, err) + assert.Nil(t, got) + }) +} + +func TestListener_refreshSession(t *testing.T) { + t.Parallel() + + t.Run("SuccessfullyRefreshes", func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + config := Config{ + ScaleSetID: 1, + Metrics: metrics.Discard, + } + + client := listenermocks.NewClient(t) + + newUUID := uuid.New() + session := &actions.RunnerScaleSetSession{ + SessionId: &newUUID, + OwnerName: "example", + RunnerScaleSet: &actions.RunnerScaleSet{}, + MessageQueueUrl: "https://example.com", + MessageQueueAccessToken: "1234567890", + Statistics: nil, + } + 
client.On("RefreshMessageSession", ctx, mock.Anything, mock.Anything).Return(session, nil).Once() + + config.Client = client + + l, err := New(config) + require.Nil(t, err) + + oldUUID := uuid.New() + l.session = &actions.RunnerScaleSetSession{ + SessionId: &oldUUID, + RunnerScaleSet: &actions.RunnerScaleSet{}, + } + + err = l.refreshSession(ctx) + assert.Nil(t, err) + assert.Equal(t, session, l.session) + }) + + t.Run("FailsToRefresh", func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + config := Config{ + ScaleSetID: 1, + Metrics: metrics.Discard, + } + + client := listenermocks.NewClient(t) + + client.On("RefreshMessageSession", ctx, mock.Anything, mock.Anything).Return(nil, errors.New("error")).Once() + + config.Client = client + + l, err := New(config) + require.Nil(t, err) + + oldUUID := uuid.New() + oldSession := &actions.RunnerScaleSetSession{ + SessionId: &oldUUID, + RunnerScaleSet: &actions.RunnerScaleSet{}, + } + l.session = oldSession + + err = l.refreshSession(ctx) + assert.NotNil(t, err) + assert.Equal(t, oldSession, l.session) + }) +} + +func TestListener_deleteLastMessage(t *testing.T) { + t.Parallel() + + t.Run("SuccessfullyDeletes", func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + config := Config{ + ScaleSetID: 1, + Metrics: metrics.Discard, + } + + client := listenermocks.NewClient(t) + + client.On("DeleteMessage", ctx, mock.Anything, mock.Anything, mock.MatchedBy(func(lastMessageID any) bool { + return lastMessageID.(int64) == int64(5) + })).Return(nil).Once() + + config.Client = client + + l, err := New(config) + require.Nil(t, err) + + l.session = &actions.RunnerScaleSetSession{} + l.lastMessageID = 5 + + err = l.deleteLastMessage(ctx) + assert.Nil(t, err) + }) + + t.Run("FailsToDelete", func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + config := Config{ + ScaleSetID: 1, + Metrics: metrics.Discard, + } + + client := listenermocks.NewClient(t) + + client.On("DeleteMessage", ctx, 
mock.Anything, mock.Anything, mock.Anything).Return(errors.New("error")).Once() + + config.Client = client + + l, err := New(config) + require.Nil(t, err) + + l.session = &actions.RunnerScaleSetSession{} + l.lastMessageID = 5 + + err = l.deleteLastMessage(ctx) + assert.NotNil(t, err) + }) +} + +func TestListener_Listen(t *testing.T) { + t.Parallel() + + t.Run("CreateSessionFails", func(t *testing.T) { + t.Parallel() + ctx := context.Background() + config := Config{ + ScaleSetID: 1, + Metrics: metrics.Discard, + } + + client := listenermocks.NewClient(t) + client.On("CreateMessageSession", ctx, mock.Anything, mock.Anything).Return(nil, assert.AnError).Once() + config.Client = client + + l, err := New(config) + require.Nil(t, err) + + err = l.Listen(ctx, nil) + assert.NotNil(t, err) + }) + + t.Run("CallHandleRegardlessOfInitialMessage", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + + config := Config{ + ScaleSetID: 1, + Metrics: metrics.Discard, + } + + client := listenermocks.NewClient(t) + + uuid := uuid.New() + session := &actions.RunnerScaleSetSession{ + SessionId: &uuid, + OwnerName: "example", + RunnerScaleSet: &actions.RunnerScaleSet{}, + MessageQueueUrl: "https://example.com", + MessageQueueAccessToken: "1234567890", + Statistics: &actions.RunnerScaleSetStatistic{}, + } + client.On("CreateMessageSession", ctx, mock.Anything, mock.Anything).Return(session, nil).Once() + config.Client = client + + l, err := New(config) + require.Nil(t, err) + + var called bool + handler := listenermocks.NewHandler(t) + handler.On("HandleDesiredRunnerCount", mock.Anything, mock.Anything). + Return(nil). + Run( + func(mock.Arguments) { + called = true + cancel() + }, + ). 
+ Once() + + err = l.Listen(ctx, handler) + assert.True(t, errors.Is(err, context.Canceled)) + assert.True(t, called) + }) +} + +func TestListener_acquireAvailableJobs(t *testing.T) { + t.Parallel() + + t.Run("FailingToAcquireJobs", func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + config := Config{ + ScaleSetID: 1, + Metrics: metrics.Discard, + } + + client := listenermocks.NewClient(t) + + client.On("AcquireJobs", ctx, mock.Anything, mock.Anything, mock.Anything).Return(nil, assert.AnError).Once() + + config.Client = client + + l, err := New(config) + require.Nil(t, err) + + uuid := uuid.New() + l.session = &actions.RunnerScaleSetSession{ + SessionId: &uuid, + OwnerName: "example", + RunnerScaleSet: &actions.RunnerScaleSet{}, + MessageQueueUrl: "https://example.com", + MessageQueueAccessToken: "1234567890", + Statistics: &actions.RunnerScaleSetStatistic{}, + } + + _, err = l.acquireAvailableJobs(ctx, []int64{1, 2, 3}) + assert.Error(t, err) + }) + + t.Run("SuccessfullyAcquiresJobsOnFirstRun", func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + config := Config{ + ScaleSetID: 1, + Metrics: metrics.Discard, + } + + client := listenermocks.NewClient(t) + + jobIDs := []int64{1, 2, 3} + + client.On("AcquireJobs", ctx, mock.Anything, mock.Anything, mock.Anything).Return(jobIDs, nil).Once() + + config.Client = client + + l, err := New(config) + require.Nil(t, err) + + uuid := uuid.New() + l.session = &actions.RunnerScaleSetSession{ + SessionId: &uuid, + OwnerName: "example", + RunnerScaleSet: &actions.RunnerScaleSet{}, + MessageQueueUrl: "https://example.com", + MessageQueueAccessToken: "1234567890", + Statistics: &actions.RunnerScaleSetStatistic{}, + } + + acquiredJobIDs, err := l.acquireAvailableJobs(ctx, []int64{1, 2, 3}) + assert.NoError(t, err) + assert.Equal(t, jobIDs, acquiredJobIDs) + }) + + t.Run("RefreshAndSucceeds", func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + config := Config{ + ScaleSetID: 1, + 
Metrics: metrics.Discard, + } + + client := listenermocks.NewClient(t) + + uuid := uuid.New() + session := &actions.RunnerScaleSetSession{ + SessionId: &uuid, + OwnerName: "example", + RunnerScaleSet: &actions.RunnerScaleSet{}, + MessageQueueUrl: "https://example.com", + MessageQueueAccessToken: "1234567890", + Statistics: nil, + } + client.On("RefreshMessageSession", ctx, mock.Anything, mock.Anything).Return(session, nil).Once() + + // First call to AcquireJobs will fail with a token expired error + client.On("AcquireJobs", ctx, mock.Anything, mock.Anything, mock.Anything).Return(nil, &actions.MessageQueueTokenExpiredError{}).Once() + + // Second call to AcquireJobs will succeed + want := []int64{1, 2, 3} + client.On("AcquireJobs", ctx, mock.Anything, mock.Anything, mock.Anything).Return(want, nil).Once() + + config.Client = client + + l, err := New(config) + require.Nil(t, err) + + l.session = &actions.RunnerScaleSetSession{ + SessionId: &uuid, + RunnerScaleSet: &actions.RunnerScaleSet{}, + } + + got, err := l.acquireAvailableJobs(ctx, want) + assert.Nil(t, err) + assert.Equal(t, want, got) + }) + + t.Run("RefreshAndFails", func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + config := Config{ + ScaleSetID: 1, + Metrics: metrics.Discard, + } + + client := listenermocks.NewClient(t) + + uuid := uuid.New() + session := &actions.RunnerScaleSetSession{ + SessionId: &uuid, + OwnerName: "example", + RunnerScaleSet: &actions.RunnerScaleSet{}, + MessageQueueUrl: "https://example.com", + MessageQueueAccessToken: "1234567890", + Statistics: nil, + } + client.On("RefreshMessageSession", ctx, mock.Anything, mock.Anything).Return(session, nil).Once() + + client.On("AcquireJobs", ctx, mock.Anything, mock.Anything, mock.Anything).Return(nil, &actions.MessageQueueTokenExpiredError{}).Twice() + + config.Client = client + + l, err := New(config) + require.Nil(t, err) + + l.session = &actions.RunnerScaleSetSession{ + SessionId: &uuid, + RunnerScaleSet: 
&actions.RunnerScaleSet{}, + } + + got, err := l.acquireAvailableJobs(ctx, []int64{1, 2, 3}) + assert.NotNil(t, err) + assert.Nil(t, got) + }) +} diff --git a/cmd/ghalistener/listener/mocks/client.go b/cmd/ghalistener/listener/mocks/client.go new file mode 100644 index 00000000..4a2311ea --- /dev/null +++ b/cmd/ghalistener/listener/mocks/client.go @@ -0,0 +1,176 @@ +// Code generated by mockery v2.36.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + actions "github.com/actions/actions-runner-controller/github/actions" + + mock "github.com/stretchr/testify/mock" + + uuid "github.com/google/uuid" +) + +// Client is an autogenerated mock type for the Client type +type Client struct { + mock.Mock +} + +// AcquireJobs provides a mock function with given fields: ctx, runnerScaleSetId, messageQueueAccessToken, requestIds +func (_m *Client) AcquireJobs(ctx context.Context, runnerScaleSetId int, messageQueueAccessToken string, requestIds []int64) ([]int64, error) { + ret := _m.Called(ctx, runnerScaleSetId, messageQueueAccessToken, requestIds) + + var r0 []int64 + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, int, string, []int64) ([]int64, error)); ok { + return rf(ctx, runnerScaleSetId, messageQueueAccessToken, requestIds) + } + if rf, ok := ret.Get(0).(func(context.Context, int, string, []int64) []int64); ok { + r0 = rf(ctx, runnerScaleSetId, messageQueueAccessToken, requestIds) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]int64) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, int, string, []int64) error); ok { + r1 = rf(ctx, runnerScaleSetId, messageQueueAccessToken, requestIds) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// CreateMessageSession provides a mock function with given fields: ctx, runnerScaleSetId, owner +func (_m *Client) CreateMessageSession(ctx context.Context, runnerScaleSetId int, owner string) (*actions.RunnerScaleSetSession, error) { + ret := _m.Called(ctx, runnerScaleSetId, 
owner) + + var r0 *actions.RunnerScaleSetSession + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, int, string) (*actions.RunnerScaleSetSession, error)); ok { + return rf(ctx, runnerScaleSetId, owner) + } + if rf, ok := ret.Get(0).(func(context.Context, int, string) *actions.RunnerScaleSetSession); ok { + r0 = rf(ctx, runnerScaleSetId, owner) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*actions.RunnerScaleSetSession) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, int, string) error); ok { + r1 = rf(ctx, runnerScaleSetId, owner) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// DeleteMessage provides a mock function with given fields: ctx, messageQueueUrl, messageQueueAccessToken, messageId +func (_m *Client) DeleteMessage(ctx context.Context, messageQueueUrl string, messageQueueAccessToken string, messageId int64) error { + ret := _m.Called(ctx, messageQueueUrl, messageQueueAccessToken, messageId) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, int64) error); ok { + r0 = rf(ctx, messageQueueUrl, messageQueueAccessToken, messageId) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// GetAcquirableJobs provides a mock function with given fields: ctx, runnerScaleSetId +func (_m *Client) GetAcquirableJobs(ctx context.Context, runnerScaleSetId int) (*actions.AcquirableJobList, error) { + ret := _m.Called(ctx, runnerScaleSetId) + + var r0 *actions.AcquirableJobList + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, int) (*actions.AcquirableJobList, error)); ok { + return rf(ctx, runnerScaleSetId) + } + if rf, ok := ret.Get(0).(func(context.Context, int) *actions.AcquirableJobList); ok { + r0 = rf(ctx, runnerScaleSetId) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*actions.AcquirableJobList) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, runnerScaleSetId) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// 
GetMessage provides a mock function with given fields: ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId +func (_m *Client) GetMessage(ctx context.Context, messageQueueUrl string, messageQueueAccessToken string, lastMessageId int64) (*actions.RunnerScaleSetMessage, error) { + ret := _m.Called(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId) + + var r0 *actions.RunnerScaleSetMessage + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, int64) (*actions.RunnerScaleSetMessage, error)); ok { + return rf(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, int64) *actions.RunnerScaleSetMessage); ok { + r0 = rf(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*actions.RunnerScaleSetMessage) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, int64) error); ok { + r1 = rf(ctx, messageQueueUrl, messageQueueAccessToken, lastMessageId) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// RefreshMessageSession provides a mock function with given fields: ctx, runnerScaleSetId, sessionId +func (_m *Client) RefreshMessageSession(ctx context.Context, runnerScaleSetId int, sessionId *uuid.UUID) (*actions.RunnerScaleSetSession, error) { + ret := _m.Called(ctx, runnerScaleSetId, sessionId) + + var r0 *actions.RunnerScaleSetSession + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, int, *uuid.UUID) (*actions.RunnerScaleSetSession, error)); ok { + return rf(ctx, runnerScaleSetId, sessionId) + } + if rf, ok := ret.Get(0).(func(context.Context, int, *uuid.UUID) *actions.RunnerScaleSetSession); ok { + r0 = rf(ctx, runnerScaleSetId, sessionId) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*actions.RunnerScaleSetSession) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, int, *uuid.UUID) error); ok { + r1 = rf(ctx, runnerScaleSetId, 
sessionId) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewClient creates a new instance of Client. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewClient(t interface { + mock.TestingT + Cleanup(func()) +}) *Client { + mock := &Client{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/cmd/ghalistener/listener/mocks/handler.go b/cmd/ghalistener/listener/mocks/handler.go new file mode 100644 index 00000000..c78fe250 --- /dev/null +++ b/cmd/ghalistener/listener/mocks/handler.go @@ -0,0 +1,58 @@ +// Code generated by mockery v2.36.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + actions "github.com/actions/actions-runner-controller/github/actions" + + mock "github.com/stretchr/testify/mock" +) + +// Handler is an autogenerated mock type for the Handler type +type Handler struct { + mock.Mock +} + +// HandleDesiredRunnerCount provides a mock function with given fields: ctx, desiredRunnerCount +func (_m *Handler) HandleDesiredRunnerCount(ctx context.Context, desiredRunnerCount int) error { + ret := _m.Called(ctx, desiredRunnerCount) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, int) error); ok { + r0 = rf(ctx, desiredRunnerCount) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// HandleJobStarted provides a mock function with given fields: ctx, jobInfo +func (_m *Handler) HandleJobStarted(ctx context.Context, jobInfo *actions.JobStarted) error { + ret := _m.Called(ctx, jobInfo) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *actions.JobStarted) error); ok { + r0 = rf(ctx, jobInfo) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewHandler creates a new instance of Handler. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. 
+// The first argument is typically a *testing.T value. +func NewHandler(t interface { + mock.TestingT + Cleanup(func()) +}) *Handler { + mock := &Handler{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/cmd/ghalistener/main.go b/cmd/ghalistener/main.go new file mode 100644 index 00000000..10436b30 --- /dev/null +++ b/cmd/ghalistener/main.go @@ -0,0 +1,40 @@ +package main + +import ( + "context" + "fmt" + "log" + "os" + "os/signal" + "syscall" + + "github.com/actions/actions-runner-controller/cmd/ghalistener/app" + "github.com/actions/actions-runner-controller/cmd/ghalistener/config" +) + +func main() { + configPath, ok := os.LookupEnv("LISTENER_CONFIG_PATH") + if !ok { + fmt.Fprintf(os.Stderr, "Error: LISTENER_CONFIG_PATH environment variable is not set\n") + os.Exit(1) + } + config, err := config.Read(configPath) + if err != nil { + log.Printf("Failed to read config: %v", err) + os.Exit(1) + } + + app, err := app.New(config) + if err != nil { + log.Printf("Failed to initialize app: %v", err) + os.Exit(1) + } + + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer stop() + + if err := app.Run(ctx); err != nil { + log.Printf("Application returned an error: %v", err) + os.Exit(1) + } +} diff --git a/cmd/ghalistener/metrics/metrics.go b/cmd/ghalistener/metrics/metrics.go new file mode 100644 index 00000000..42618387 --- /dev/null +++ b/cmd/ghalistener/metrics/metrics.go @@ -0,0 +1,387 @@ +package metrics + +import ( + "context" + "net/http" + "strconv" + + "github.com/actions/actions-runner-controller/github/actions" + "github.com/go-logr/logr" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" +) + +const ( + labelKeyRunnerScaleSetName = "name" + labelKeyRunnerScaleSetNamespace = "namespace" + labelKeyEnterprise = "enterprise" + labelKeyOrganization = "organization" + labelKeyRepository = "repository" + 
labelKeyJobName = "job_name" + labelKeyJobWorkflowRef = "job_workflow_ref" + labelKeyEventName = "event_name" + labelKeyJobResult = "job_result" + labelKeyRunnerID = "runner_id" + labelKeyRunnerName = "runner_name" +) + +const githubScaleSetSubsystem = "gha" + +// labels +var ( + scaleSetLabels = []string{ + labelKeyRunnerScaleSetName, + labelKeyRepository, + labelKeyOrganization, + labelKeyEnterprise, + labelKeyRunnerScaleSetNamespace, + } + + jobLabels = []string{ + labelKeyRepository, + labelKeyOrganization, + labelKeyEnterprise, + labelKeyJobName, + labelKeyJobWorkflowRef, + labelKeyEventName, + } + + completedJobsTotalLabels = append(jobLabels, labelKeyJobResult, labelKeyRunnerID, labelKeyRunnerName) + jobExecutionDurationLabels = append(jobLabels, labelKeyJobResult, labelKeyRunnerID, labelKeyRunnerName) + startedJobsTotalLabels = append(jobLabels, labelKeyRunnerID, labelKeyRunnerName) + jobStartupDurationLabels = append(jobLabels, labelKeyRunnerID, labelKeyRunnerName) +) + +var ( + assignedJobs = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Subsystem: githubScaleSetSubsystem, + Name: "assigned_jobs", + Help: "Number of jobs assigned to this scale set.", + }, + scaleSetLabels, + ) + + runningJobs = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Subsystem: githubScaleSetSubsystem, + Name: "running_jobs", + Help: "Number of jobs running (or about to be run).", + }, + scaleSetLabels, + ) + + registeredRunners = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Subsystem: githubScaleSetSubsystem, + Name: "registered_runners", + Help: "Number of runners registered by the scale set.", + }, + scaleSetLabels, + ) + + busyRunners = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Subsystem: githubScaleSetSubsystem, + Name: "busy_runners", + Help: "Number of registered runners running a job.", + }, + scaleSetLabels, + ) + + minRunners = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Subsystem: githubScaleSetSubsystem, + Name: "min_runners", + Help: "Minimum 
number of runners.", + }, + scaleSetLabels, + ) + + maxRunners = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Subsystem: githubScaleSetSubsystem, + Name: "max_runners", + Help: "Maximum number of runners.", + }, + scaleSetLabels, + ) + + desiredRunners = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Subsystem: githubScaleSetSubsystem, + Name: "desired_runners", + Help: "Number of runners desired by the scale set.", + }, + scaleSetLabels, + ) + + idleRunners = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Subsystem: githubScaleSetSubsystem, + Name: "idle_runners", + Help: "Number of registered runners not running a job.", + }, + scaleSetLabels, + ) + + startedJobsTotal = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Subsystem: githubScaleSetSubsystem, + Name: "started_jobs_total", + Help: "Total number of jobs started.", + }, + startedJobsTotalLabels, + ) + + completedJobsTotal = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "completed_jobs_total", + Help: "Total number of jobs completed.", + Subsystem: githubScaleSetSubsystem, + }, + completedJobsTotalLabels, + ) + + jobStartupDurationSeconds = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Subsystem: githubScaleSetSubsystem, + Name: "job_startup_duration_seconds", + Help: "Time spent waiting for workflow job to get started on the runner owned by the scale set (in seconds).", + Buckets: runtimeBuckets, + }, + jobStartupDurationLabels, + ) + + jobExecutionDurationSeconds = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Subsystem: githubScaleSetSubsystem, + Name: "job_execution_duration_seconds", + Help: "Time spent executing workflow jobs by the scale set (in seconds).", + Buckets: runtimeBuckets, + }, + jobExecutionDurationLabels, + ) +) + +var runtimeBuckets []float64 = []float64{ + 0.01, + 0.05, + 0.1, + 0.5, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 12, + 15, + 18, + 20, + 25, + 30, + 40, + 50, + 60, + 70, + 80, + 90, + 100, + 110, + 120, + 150, 
+ 180, + 210, + 240, + 300, + 360, + 420, + 480, + 540, + 600, + 900, + 1200, + 1800, + 2400, + 3000, + 3600, +} + +type baseLabels struct { + scaleSetName string + scaleSetNamespace string + enterprise string + organization string + repository string +} + +func (b *baseLabels) jobLabels(jobBase *actions.JobMessageBase) prometheus.Labels { + return prometheus.Labels{ + labelKeyEnterprise: b.enterprise, + labelKeyOrganization: b.organization, + labelKeyRepository: b.repository, + labelKeyJobName: jobBase.JobDisplayName, + labelKeyJobWorkflowRef: jobBase.JobWorkflowRef, + labelKeyEventName: jobBase.EventName, + } +} + +func (b *baseLabels) scaleSetLabels() prometheus.Labels { + return prometheus.Labels{ + labelKeyRunnerScaleSetName: b.scaleSetName, + labelKeyRunnerScaleSetNamespace: b.scaleSetNamespace, + labelKeyEnterprise: b.enterprise, + labelKeyOrganization: b.organization, + labelKeyRepository: b.repository, + } +} + +func (b *baseLabels) completedJobLabels(msg *actions.JobCompleted) prometheus.Labels { + l := b.jobLabels(&msg.JobMessageBase) + l[labelKeyRunnerID] = strconv.Itoa(msg.RunnerId) + l[labelKeyJobResult] = msg.Result + l[labelKeyRunnerName] = msg.RunnerName + return l +} + +func (b *baseLabels) startedJobLabels(msg *actions.JobStarted) prometheus.Labels { + l := b.jobLabels(&msg.JobMessageBase) + l[labelKeyRunnerID] = strconv.Itoa(msg.RunnerId) + l[labelKeyRunnerName] = msg.RunnerName + return l +} + +//go:generate mockery --name Publisher --output ./mocks --outpkg mocks --case underscore +type Publisher interface { + PublishStatic(min, max int) + PublishStatistics(stats *actions.RunnerScaleSetStatistic) + PublishJobStarted(msg *actions.JobStarted) + PublishJobCompleted(msg *actions.JobCompleted) + PublishDesiredRunners(count int) +} + +//go:generate mockery --name ServerPublisher --output ./mocks --outpkg mocks --case underscore +type ServerPublisher interface { + Publisher + ListenAndServe(ctx context.Context) error +} + +var _ Publisher = 
&discard{}
+var _ ServerPublisher = &exporter{}
+
+// Discard is a no-op Publisher for when metrics export is disabled.
+var Discard Publisher = &discard{}
+
+// exporter publishes scale-set metrics to Prometheus via an HTTP server.
+type exporter struct {
+	logger logr.Logger
+	baseLabels
+	srv *http.Server
+}
+
+type ExporterConfig struct {
+	ScaleSetName      string
+	ScaleSetNamespace string
+	Enterprise        string
+	Organization      string
+	Repository        string
+	ServerAddr        string
+	ServerEndpoint    string
+	Logger            logr.Logger
+}
+
+// NewExporter registers all scale-set collectors on a fresh registry and
+// returns a ServerPublisher serving them at config.ServerEndpoint.
+func NewExporter(config ExporterConfig) ServerPublisher {
+	reg := prometheus.NewRegistry()
+	reg.MustRegister(
+		assignedJobs,
+		runningJobs,
+		registeredRunners,
+		busyRunners,
+		minRunners,
+		maxRunners,
+		desiredRunners,
+		idleRunners,
+		startedJobsTotal,
+		completedJobsTotal,
+		jobStartupDurationSeconds,
+		jobExecutionDurationSeconds,
+	)
+
+	mux := http.NewServeMux()
+	mux.Handle(
+		config.ServerEndpoint,
+		promhttp.HandlerFor(reg, promhttp.HandlerOpts{Registry: reg}),
+	)
+
+	return &exporter{
+		logger: config.Logger.WithName("metrics"),
+		baseLabels: baseLabels{
+			scaleSetName:      config.ScaleSetName,
+			scaleSetNamespace: config.ScaleSetNamespace,
+			enterprise:        config.Enterprise,
+			organization:      config.Organization,
+			repository:        config.Repository,
+		},
+		srv: &http.Server{
+			Addr:    config.ServerAddr,
+			Handler: mux,
+		},
+	}
+}
+
+// ListenAndServe blocks serving metrics until ctx is done, then shuts the
+// server down gracefully.
+func (e *exporter) ListenAndServe(ctx context.Context) error {
+	e.logger.Info("starting metrics server", "addr", e.srv.Addr)
+	go func() {
+		<-ctx.Done()
+		e.logger.Info("stopping metrics server")
+		e.srv.Shutdown(context.Background()) // ctx is already canceled here; a fresh context lets Shutdown drain in-flight requests
+	}()
+	return e.srv.ListenAndServe()
+}
+
+func (e *exporter) PublishStatic(min, max int) {
+	l := e.scaleSetLabels()
+	maxRunners.With(l).Set(float64(max))
+	minRunners.With(l).Set(float64(min))
+}
+
+func (e *exporter) PublishStatistics(stats *actions.RunnerScaleSetStatistic) {
+	l := e.scaleSetLabels()
+
+	assignedJobs.With(l).Set(float64(stats.TotalAssignedJobs))
+	runningJobs.With(l).Set(float64(stats.TotalRunningJobs))
+	registeredRunners.With(l).Set(float64(stats.TotalRegisteredRunners))
+	busyRunners.With(l).Set(float64(stats.TotalBusyRunners))
+	idleRunners.With(l).Set(float64(stats.TotalIdleRunners))
+}
+
+func (e *exporter) PublishJobStarted(msg *actions.JobStarted) {
+	l := e.startedJobLabels(msg)
+	startedJobsTotal.With(l).Inc()
+
+	// Unix() truncates to whole seconds, so durations have second precision.
+	startupDuration := msg.JobMessageBase.RunnerAssignTime.Unix() - msg.JobMessageBase.ScaleSetAssignTime.Unix()
+	jobStartupDurationSeconds.With(l).Observe(float64(startupDuration))
+}
+
+func (e *exporter) PublishJobCompleted(msg *actions.JobCompleted) {
+	l := e.completedJobLabels(msg)
+	completedJobsTotal.With(l).Inc()
+
+	executionDuration := msg.JobMessageBase.FinishTime.Unix() - msg.JobMessageBase.RunnerAssignTime.Unix()
+	jobExecutionDurationSeconds.With(l).Observe(float64(executionDuration))
+}
+
+func (e *exporter) PublishDesiredRunners(count int) {
+	desiredRunners.With(e.scaleSetLabels()).Set(float64(count))
+}
+
+type discard struct{}
+
+func (*discard) PublishStatic(int, int)                             {}
+func (*discard) PublishStatistics(*actions.RunnerScaleSetStatistic) {}
+func (*discard) PublishJobStarted(*actions.JobStarted)              {}
+func (*discard) PublishJobCompleted(*actions.JobCompleted)          {}
+func (*discard) PublishDesiredRunners(int)                          {}
diff --git a/cmd/ghalistener/metrics/mocks/publisher.go b/cmd/ghalistener/metrics/mocks/publisher.go
new file mode 100644
index 00000000..08858594
--- /dev/null
+++ b/cmd/ghalistener/metrics/mocks/publisher.go
@@ -0,0 +1,53 @@
+// Code generated by mockery v2.36.1. DO NOT EDIT.
+ +package mocks + +import ( + actions "github.com/actions/actions-runner-controller/github/actions" + + mock "github.com/stretchr/testify/mock" +) + +// Publisher is an autogenerated mock type for the Publisher type +type Publisher struct { + mock.Mock +} + +// PublishDesiredRunners provides a mock function with given fields: count +func (_m *Publisher) PublishDesiredRunners(count int) { + _m.Called(count) +} + +// PublishJobCompleted provides a mock function with given fields: msg +func (_m *Publisher) PublishJobCompleted(msg *actions.JobCompleted) { + _m.Called(msg) +} + +// PublishJobStarted provides a mock function with given fields: msg +func (_m *Publisher) PublishJobStarted(msg *actions.JobStarted) { + _m.Called(msg) +} + +// PublishStatic provides a mock function with given fields: min, max +func (_m *Publisher) PublishStatic(min int, max int) { + _m.Called(min, max) +} + +// PublishStatistics provides a mock function with given fields: stats +func (_m *Publisher) PublishStatistics(stats *actions.RunnerScaleSetStatistic) { + _m.Called(stats) +} + +// NewPublisher creates a new instance of Publisher. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewPublisher(t interface { + mock.TestingT + Cleanup(func()) +}) *Publisher { + mock := &Publisher{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/cmd/ghalistener/metrics/mocks/server_publisher.go b/cmd/ghalistener/metrics/mocks/server_publisher.go new file mode 100644 index 00000000..01aac02e --- /dev/null +++ b/cmd/ghalistener/metrics/mocks/server_publisher.go @@ -0,0 +1,69 @@ +// Code generated by mockery v2.36.1. DO NOT EDIT. 
+ +package mocks + +import ( + context "context" + + actions "github.com/actions/actions-runner-controller/github/actions" + + mock "github.com/stretchr/testify/mock" +) + +// ServerPublisher is an autogenerated mock type for the ServerPublisher type +type ServerPublisher struct { + mock.Mock +} + +// ListenAndServe provides a mock function with given fields: ctx +func (_m *ServerPublisher) ListenAndServe(ctx context.Context) error { + ret := _m.Called(ctx) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(ctx) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// PublishDesiredRunners provides a mock function with given fields: count +func (_m *ServerPublisher) PublishDesiredRunners(count int) { + _m.Called(count) +} + +// PublishJobCompleted provides a mock function with given fields: msg +func (_m *ServerPublisher) PublishJobCompleted(msg *actions.JobCompleted) { + _m.Called(msg) +} + +// PublishJobStarted provides a mock function with given fields: msg +func (_m *ServerPublisher) PublishJobStarted(msg *actions.JobStarted) { + _m.Called(msg) +} + +// PublishStatic provides a mock function with given fields: min, max +func (_m *ServerPublisher) PublishStatic(min int, max int) { + _m.Called(min, max) +} + +// PublishStatistics provides a mock function with given fields: stats +func (_m *ServerPublisher) PublishStatistics(stats *actions.RunnerScaleSetStatistic) { + _m.Called(stats) +} + +// NewServerPublisher creates a new instance of ServerPublisher. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. 
+func NewServerPublisher(t interface { + mock.TestingT + Cleanup(func()) +}) *ServerPublisher { + mock := &ServerPublisher{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/cmd/ghalistener/worker/worker.go b/cmd/ghalistener/worker/worker.go new file mode 100644 index 00000000..12fd4d79 --- /dev/null +++ b/cmd/ghalistener/worker/worker.go @@ -0,0 +1,228 @@ +package worker + +import ( + "context" + "encoding/json" + "fmt" + "math" + + "github.com/actions/actions-runner-controller/apis/actions.github.com/v1alpha1" + "github.com/actions/actions-runner-controller/cmd/ghalistener/listener" + "github.com/actions/actions-runner-controller/github/actions" + "github.com/actions/actions-runner-controller/logging" + jsonpatch "github.com/evanphx/json-patch" + "github.com/go-logr/logr" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" +) + +const workerName = "kubernetesworker" + +type Option func(*Worker) + +func WithLogger(logger logr.Logger) Option { + return func(w *Worker) { + logger = logger.WithName(workerName) + w.logger = &logger + } +} + +type Config struct { + EphemeralRunnerSetNamespace string + EphemeralRunnerSetName string + MaxRunners int + MinRunners int +} + +// The Worker's role is to process the messages it receives from the listener. +// It then initiates Kubernetes API requests to carry out the necessary actions. 
+type Worker struct { + clientset *kubernetes.Clientset + config Config + lastPatch int + logger *logr.Logger +} + +var _ listener.Handler = (*Worker)(nil) + +func New(config Config, options ...Option) (*Worker, error) { + w := &Worker{ + config: config, + lastPatch: -1, + } + + conf, err := rest.InClusterConfig() + if err != nil { + return nil, err + } + + clientset, err := kubernetes.NewForConfig(conf) + if err != nil { + return nil, err + } + + w.clientset = clientset + + for _, option := range options { + option(w) + } + + if err := w.applyDefaults(); err != nil { + return nil, err + } + + return w, nil +} + +func (w *Worker) applyDefaults() error { + if w.logger == nil { + logger, err := logging.NewLogger(logging.LogLevelDebug, logging.LogFormatJSON) + if err != nil { + return fmt.Errorf("NewLogger failed: %w", err) + } + logger = logger.WithName(workerName) + w.logger = &logger + } + + return nil +} + +// HandleJobStarted updates the job information for the ephemeral runner when a job is started. +// It takes a context and a jobInfo parameter which contains the details of the started job. +// This update marks the ephemeral runner so that the controller would have more context +// about the ephemeral runner that should not be deleted when scaling down. +// It returns an error if there is any issue with updating the job information. 
+func (w *Worker) HandleJobStarted(ctx context.Context, jobInfo *actions.JobStarted) error {
+	w.logger.Info("Updating job info for the runner",
+		"runnerName", jobInfo.RunnerName,
+		"ownerName", jobInfo.OwnerName,
+		"repoName", jobInfo.RepositoryName,
+		"workflowRef", jobInfo.JobWorkflowRef,
+		"workflowRunId", jobInfo.WorkflowRunId,
+		"jobDisplayName", jobInfo.JobDisplayName,
+		"requestId", jobInfo.RunnerRequestId)
+
+	original, err := json.Marshal(&v1alpha1.EphemeralRunner{})
+	if err != nil {
+		return fmt.Errorf("failed to marshal empty ephemeral runner: %w", err)
+	}
+
+	patch, err := json.Marshal(
+		&v1alpha1.EphemeralRunner{
+			Status: v1alpha1.EphemeralRunnerStatus{
+				JobRequestId:      jobInfo.RunnerRequestId,
+				JobRepositoryName: fmt.Sprintf("%s/%s", jobInfo.OwnerName, jobInfo.RepositoryName),
+				WorkflowRunId:     jobInfo.WorkflowRunId,
+				JobWorkflowRef:    jobInfo.JobWorkflowRef,
+				JobDisplayName:    jobInfo.JobDisplayName,
+			},
+		},
+	)
+	if err != nil {
+		return fmt.Errorf("failed to marshal ephemeral runner patch: %w", err)
+	}
+
+	mergePatch, err := jsonpatch.CreateMergePatch(original, patch)
+	if err != nil {
+		return fmt.Errorf("failed to create merge patch json for ephemeral runner: %w", err)
+	}
+
+	w.logger.Info("Updating ephemeral runner with merge patch", "json", string(mergePatch))
+
+	patchedStatus := &v1alpha1.EphemeralRunner{}
+	err = w.clientset.RESTClient().
+		Patch(types.MergePatchType).
+		Prefix("apis", v1alpha1.GroupVersion.Group, v1alpha1.GroupVersion.Version).
+		Namespace(w.config.EphemeralRunnerSetNamespace).
+		Resource("ephemeralrunners"). // CRD REST paths use the lowercase plural; "EphemeralRunners" would not resolve
+		Name(jobInfo.RunnerName).
+		SubResource("status").
+		Body(mergePatch).
+		Do(ctx).
+ Into(patchedStatus) + if err != nil { + return fmt.Errorf("could not patch ephemeral runner status, patch JSON: %s, error: %w", string(mergePatch), err) + } + + w.logger.Info("Ephemeral runner status updated with the merge patch successfully.") + + return nil +} + +// HandleDesiredRunnerCount handles the desired runner count by scaling the ephemeral runner set. +// The function calculates the target runner count based on the minimum and maximum runner count configuration. +// If the target runner count is the same as the last patched count, it skips patching and returns nil. +// Otherwise, it creates a merge patch JSON for updating the ephemeral runner set with the desired count. +// The function then scales the ephemeral runner set by applying the merge patch. +// Finally, it logs the scaled ephemeral runner set details and returns nil if successful. +// If any error occurs during the process, it returns an error with a descriptive message. +func (w *Worker) HandleDesiredRunnerCount(ctx context.Context, count int) error { + targetRunnerCount := int(math.Max(math.Min(float64(w.config.MaxRunners), float64(count)), float64(w.config.MinRunners))) + + logValues := []any{ + "assigned job", count, + "decision", targetRunnerCount, + "min", w.config.MinRunners, + "max", w.config.MaxRunners, + "currentRunnerCount", w.lastPatch, + } + + if targetRunnerCount == w.lastPatch { + w.logger.Info("Skipping patching of EphemeralRunnerSet as the desired count has not changed", logValues...) 
+		return nil
+	}
+
+	original, err := json.Marshal(
+		&v1alpha1.EphemeralRunnerSet{
+			Spec: v1alpha1.EphemeralRunnerSetSpec{
+				Replicas: -1,
+			},
+		},
+	)
+	if err != nil {
+		return fmt.Errorf("failed to marshal empty ephemeral runner set: %w", err)
+	}
+
+	patch, err := json.Marshal(
+		&v1alpha1.EphemeralRunnerSet{
+			Spec: v1alpha1.EphemeralRunnerSetSpec{
+				Replicas: targetRunnerCount, // use the min/max-clamped decision, not the raw assigned-job count
+			},
+		},
+	)
+	if err != nil {
+		w.logger.Error(err, "could not marshal patch ephemeral runner set")
+		return err
+	}
+
+	mergePatch, err := jsonpatch.CreateMergePatch(original, patch)
+	if err != nil {
+		return fmt.Errorf("failed to create merge patch json for ephemeral runner set: %w", err)
+	}
+
+	w.logger.Info("Created merge patch json for EphemeralRunnerSet update", "json", string(mergePatch))
+
+	w.logger.Info("Scaling ephemeral runner set", logValues...)
+
+	patchedEphemeralRunnerSet := &v1alpha1.EphemeralRunnerSet{}
+	err = w.clientset.RESTClient().
+		Patch(types.MergePatchType).
+		Prefix("apis", v1alpha1.GroupVersion.Group, v1alpha1.GroupVersion.Version).
+		Namespace(w.config.EphemeralRunnerSetNamespace).
+		Resource("ephemeralrunnersets").
+		Name(w.config.EphemeralRunnerSetName).
+		Body([]byte(mergePatch)).
+		Do(ctx).
+		Into(patchedEphemeralRunnerSet)
+	if err != nil {
+		return fmt.Errorf("could not patch ephemeral runner set, patch JSON: %s, error: %w", string(mergePatch), err)
+	}
+
+	w.logger.Info("Ephemeral runner set scaled.",
+		"namespace", w.config.EphemeralRunnerSetNamespace,
+		"name", w.config.EphemeralRunnerSetName,
+		"replicas", patchedEphemeralRunnerSet.Spec.Replicas,
+	)
+	return nil // NOTE(review): w.lastPatch is never updated, so the skip-check above always compares against -1; consider w.lastPatch = targetRunnerCount after a successful patch
+}
diff --git a/cmd/githubrunnerscalesetlistener/mock_KubernetesManager.go b/cmd/githubrunnerscalesetlistener/mock_KubernetesManager.go
index 680a47c3..8c44598c 100644
--- a/cmd/githubrunnerscalesetlistener/mock_KubernetesManager.go
+++ b/cmd/githubrunnerscalesetlistener/mock_KubernetesManager.go
@@ -1,4 +1,4 @@
-// Code generated by mockery v2.33.2. DO NOT EDIT.
+// Code generated by mockery v2.36.1. DO NOT EDIT. package main diff --git a/cmd/githubrunnerscalesetlistener/mock_RunnerScaleSetClient.go b/cmd/githubrunnerscalesetlistener/mock_RunnerScaleSetClient.go index 929ad7bf..80ba900a 100644 --- a/cmd/githubrunnerscalesetlistener/mock_RunnerScaleSetClient.go +++ b/cmd/githubrunnerscalesetlistener/mock_RunnerScaleSetClient.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.33.2. DO NOT EDIT. +// Code generated by mockery v2.36.1. DO NOT EDIT. package main diff --git a/controllers/actions.github.com/resourcebuilder.go b/controllers/actions.github.com/resourcebuilder.go index 895d9b56..c306fb88 100644 --- a/controllers/actions.github.com/resourcebuilder.go +++ b/controllers/actions.github.com/resourcebuilder.go @@ -38,8 +38,11 @@ var commonLabelKeys = [...]string{ const labelValueKubernetesPartOf = "gha-runner-scale-set" -var scaleSetListenerLogLevel = DefaultScaleSetListenerLogLevel -var scaleSetListenerLogFormat = DefaultScaleSetListenerLogFormat +var ( + scaleSetListenerLogLevel = DefaultScaleSetListenerLogLevel + scaleSetListenerLogFormat = DefaultScaleSetListenerLogFormat + scaleSetListenerEntrypoint = "/ghalistener" +) func SetListenerLoggingParameters(level string, format string) bool { switch level { @@ -59,6 +62,12 @@ func SetListenerLoggingParameters(level string, format string) bool { return true } +func SetListenerEntrypoint(entrypoint string) { + if entrypoint != "" { + scaleSetListenerEntrypoint = entrypoint + } +} + type resourceBuilder struct{} func (b *resourceBuilder) newAutoScalingListener(autoscalingRunnerSet *v1alpha1.AutoscalingRunnerSet, ephemeralRunnerSet *v1alpha1.EphemeralRunnerSet, namespace, image string, imagePullSecrets []corev1.LocalObjectReference) (*v1alpha1.AutoscalingListener, error) { @@ -225,7 +234,7 @@ func (b *resourceBuilder) newScaleSetListenerPod(autoscalingListener *v1alpha1.A Image: autoscalingListener.Spec.Image, Env: listenerEnv, Command: []string{ - 
"/github-runnerscaleset-listener", + scaleSetListenerEntrypoint, }, Ports: ports, VolumeMounts: []corev1.VolumeMount{ diff --git a/github/actions/mock_ActionsService.go b/github/actions/mock_ActionsService.go index b4a25df9..0216cf30 100644 --- a/github/actions/mock_ActionsService.go +++ b/github/actions/mock_ActionsService.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.33.2. DO NOT EDIT. +// Code generated by mockery v2.36.1. DO NOT EDIT. package actions diff --git a/github/actions/mock_SessionService.go b/github/actions/mock_SessionService.go index e4b06ec0..ed403eee 100644 --- a/github/actions/mock_SessionService.go +++ b/github/actions/mock_SessionService.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.33.2. DO NOT EDIT. +// Code generated by mockery v2.36.1. DO NOT EDIT. package actions diff --git a/main.go b/main.go index a90791e9..284920da 100644 --- a/main.go +++ b/main.go @@ -203,6 +203,8 @@ func main() { log.Info("Using default AutoscalingListener logging parameters", "LogLevel", actionsgithubcom.DefaultScaleSetListenerLogLevel, "LogFormat", actionsgithubcom.DefaultScaleSetListenerLogFormat) } + actionsgithubcom.SetListenerEntrypoint(os.Getenv("LISTENER_ENTRYPOINT")) + var webhookServer webhook.Server if port != 0 { webhookServer = webhook.NewServer(webhook.Options{