diff --git a/github/fake/fake.go b/github/fake/fake.go new file mode 100644 index 00000000..a9ed3667 --- /dev/null +++ b/github/fake/fake.go @@ -0,0 +1,86 @@ +package fake + +import ( + "fmt" + "net/http" + "net/http/httptest" + "time" +) + +const ( + RegistrationToken = "fake-registration-token" + + RunnersListBody = ` +{ + "total_count": 2, + "runners": [ + {"id": 1, "name": "test1", "os": "linux", "status": "online"}, + {"id": 2, "name": "test2", "os": "linux", "status": "offline"} + ] +} +` +) + +type handler struct { + Status int + Body string +} + +func (h *handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(h.Status) + fmt.Fprintf(w, h.Body) +} + +func NewServer() *httptest.Server { + routes := map[string]handler{ + // For CreateRegistrationToken + "/repos/test/valid/actions/runners/registration-token": handler{ + Status: http.StatusCreated, + Body: fmt.Sprintf("{\"token\": \"%s\", \"expires_at\": \"%s\"}", RegistrationToken, time.Now().Add(time.Hour*1).Format(time.RFC3339)), + }, + "/repos/test/invalid/actions/runners/registration-token": handler{ + Status: http.StatusOK, + Body: fmt.Sprintf("{\"token\": \"%s\", \"expires_at\": \"%s\"}", RegistrationToken, time.Now().Add(time.Hour*1).Format(time.RFC3339)), + }, + "/repos/test/error/actions/runners/registration-token": handler{ + Status: http.StatusBadRequest, + Body: "", + }, + + // For ListRunners + "/repos/test/valid/actions/runners": handler{ + Status: http.StatusOK, + Body: RunnersListBody, + }, + "/repos/test/invalid/actions/runners": handler{ + Status: http.StatusNoContent, + Body: "", + }, + "/repos/test/error/actions/runners": handler{ + Status: http.StatusBadRequest, + Body: "", + }, + + // For RemoveRunner + "/repos/test/valid/actions/runners/1": handler{ + Status: http.StatusNoContent, + Body: "", + }, + "/repos/test/invalid/actions/runners/1": handler{ + Status: http.StatusOK, + Body: "", + }, + "/repos/test/error/actions/runners/1": handler{ + Status: http.StatusBadRequest, + Body: "", + }, + } + + mux := http.NewServeMux() + for path, handler := range routes { + h := handler + mux.Handle(path, &h) + } + + return httptest.NewServer(mux) +} diff --git a/github/github.go b/github/github.go new file mode 100644 index 00000000..401eee4f --- /dev/null +++ b/github/github.go @@ -0,0 +1,147 @@ +package github + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + "sync" + "time" + + "github.com/bradleyfalzon/ghinstallation" + "github.com/google/go-github/v31/github" + "golang.org/x/oauth2" +) + +type Client struct { + *github.Client + regTokens map[string]*github.RegistrationToken + mu sync.Mutex +} + +// NewClient returns a client authenticated as a GitHub App. +func NewClient(appID, installationID int64, privateKeyPath string) (*Client, error) { + tr, err := ghinstallation.NewKeyFromFile(http.DefaultTransport, appID, installationID, privateKeyPath) + if err != nil { + return nil, fmt.Errorf("authentication failed: %v", err) + } + + return &Client{ + Client: github.NewClient(&http.Client{Transport: tr}), + regTokens: map[string]*github.RegistrationToken{}, + mu: sync.Mutex{}, + }, nil +} + +// NewClient returns a client authenticated with personal access token. +func NewClientWithAccessToken(token string) (*Client, error) { + tc := oauth2.NewClient(context.Background(), oauth2.StaticTokenSource( + &oauth2.Token{AccessToken: token}, + )) + + return &Client{ + Client: github.NewClient(tc), + regTokens: map[string]*github.RegistrationToken{}, + mu: sync.Mutex{}, + }, nil +} + +// GetRegistrationToken returns a registration token tied with the name of repository and runner. +func (c *Client) GetRegistrationToken(ctx context.Context, repository, name string) (*github.RegistrationToken, error) { + c.mu.Lock() + defer c.mu.Unlock() + + owner, repo, err := splitOwnerAndRepo(repository) + if err != nil { + return nil, err + } + + key := fmt.Sprintf("%s/%s", repo, name) + rt, ok := c.regTokens[key] + if ok && rt.GetExpiresAt().After(time.Now().Add(-10*time.Minute)) { + return rt, nil + } + + rt, res, err := c.Client.Actions.CreateRegistrationToken(ctx, owner, repo) + if err != nil { + return nil, fmt.Errorf("failed to create registration token: %v", err) + } + + if res.StatusCode != 201 { + return nil, fmt.Errorf("unexpected status: %d", res.StatusCode) + } + + c.regTokens[key] = rt + go func() { + c.cleanup() + }() + + return rt, nil +} + +// RemoveRunner removes a runner with specified runner ID from repocitory. +func (c *Client) RemoveRunner(ctx context.Context, repository string, runnerID int64) error { + owner, repo, err := splitOwnerAndRepo(repository) + if err != nil { + return err + } + + res, err := c.Client.Actions.RemoveRunner(ctx, owner, repo, runnerID) + if err != nil { + return fmt.Errorf("failed to remove runner: %v", err) + } + + if res.StatusCode != 204 { + return fmt.Errorf("unexpected status: %d", res.StatusCode) + } + + return nil +} + +// ListRunners returns a list of runners of specified repository name. +func (c *Client) ListRunners(ctx context.Context, repository string) ([]*github.Runner, error) { + var runners []*github.Runner + + owner, repo, err := splitOwnerAndRepo(repository) + if err != nil { + return runners, err + } + + opts := github.ListOptions{PerPage: 10} + for { + list, res, err := c.Client.Actions.ListRunners(ctx, owner, repo, &opts) + if err != nil { + return runners, fmt.Errorf("failed to remove runner: %v", err) + } + + runners = append(runners, list.Runners...) + if res.NextPage == 0 { + break + } + opts.Page = res.NextPage + } + + return runners, nil +} + +// cleanup removes expired registration tokens. +func (c *Client) cleanup() { + c.mu.Lock() + defer c.mu.Unlock() + + for key, rt := range c.regTokens { + if rt.GetExpiresAt().Before(time.Now()) { + delete(c.regTokens, key) + } + } +} + +// splitOwnerAndRepo splits specified repository name to the owner and repo name. +func splitOwnerAndRepo(repo string) (string, string, error) { + chunk := strings.Split(repo, "/") + if len(chunk) != 2 { + return "", "", errors.New("invalid repository name") + } + return chunk[0], chunk[1], nil +} diff --git a/github/github_test.go b/github/github_test.go new file mode 100644 index 00000000..bd048efa --- /dev/null +++ b/github/github_test.go @@ -0,0 +1,124 @@ +package github + +import ( + "context" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/google/go-github/v31/github" + "github.com/summerwind/actions-runner-controller/github/fake" +) + +var server *httptest.Server + +func newTestClient() *Client { + client, err := NewClientWithAccessToken("token") + if err != nil { + panic(err) + } + + baseURL, err := url.Parse(server.URL + "/") + if err != nil { + panic(err) + } + client.Client.BaseURL = baseURL + + return client +} + +func TestMain(m *testing.M) { + server = fake.NewServer() + defer server.Close() + m.Run() +} + +func TestGetRegistrationToken(t *testing.T) { + tests := []struct { + repo string + token string + err bool + }{ + {repo: "test/valid", token: fake.RegistrationToken, err: false}, + {repo: "test/invalid", token: "", err: true}, + {repo: "test/error", token: "", err: true}, + } + + client := newTestClient() + for i, tt := range tests { + rt, err := client.GetRegistrationToken(context.Background(), tt.repo, "test") + if !tt.err && err != nil { + t.Errorf("[%d] unexpected error: %v", i, err) + } + if tt.token != rt.GetToken() { + t.Errorf("[%d] unexpected token: %v", i, rt.GetToken()) + } + } +} + +func TestListRunners(t *testing.T) { + tests := []struct { + repo string + length int + err bool + }{ + {repo: "test/valid", length: 2, err: false}, + {repo: "test/invalid", length: 0, err: true}, + {repo: "test/error", length: 0, err: true}, + } + + client := newTestClient() + for i, tt := range tests { + runners, err := client.ListRunners(context.Background(), tt.repo) + if !tt.err && err != nil { + t.Errorf("[%d] unexpected error: %v", i, err) + } + if tt.length != len(runners) { + t.Errorf("[%d] unexpected runners list: %v", i, runners) + } + } +} + +func TestRemoveRunner(t *testing.T) { + tests := []struct { + repo string + err bool + }{ + {repo: "test/valid", err: false}, + {repo: "test/invalid", err: true}, + {repo: "test/error", err: true}, + } + + client := newTestClient() + for i, tt := range tests { + err := client.RemoveRunner(context.Background(), tt.repo, int64(1)) + if !tt.err && err != nil { + t.Errorf("[%d] unexpected error: %v", i, err) + } + } +} + +func TestCleanup(t *testing.T) { + token := "token" + + client := newTestClient() + client.regTokens = map[string]*github.RegistrationToken{ + "active": &github.RegistrationToken{ + Token: &token, + ExpiresAt: &github.Timestamp{Time: time.Now().Add(time.Hour * 1)}, + }, + "expired": &github.RegistrationToken{ + Token: &token, + ExpiresAt: &github.Timestamp{Time: time.Now().Add(-time.Hour * 1)}, + }, + } + + client.cleanup() + if _, ok := client.regTokens["active"]; !ok { + t.Errorf("active token was accidentally removed") + } + if _, ok := client.regTokens["expired"]; ok { + t.Errorf("expired token still exists") + } +}