// actions-runner-controller/pkg/hookdeliveryforwarder/forwarder.go
// (254 lines, 4.8 KiB, Go)

package hookdeliveryforwarder
import (
"bytes"
"context"
"errors"
"fmt"
"net/http"
"os"
"sort"
"strings"
"time"
"github.com/actions/actions-runner-controller/github"
gogithub "github.com/google/go-github/v52/github"
)
// Forwarder polls the deliveries of a GitHub repository webhook and POSTs each
// delivery's raw request payload to Target.
type Forwarder struct {
// Repo names the repository whose hook is polled. Run splits it on "/" and
// uses the first segment as the owner and the second segment, when present,
// as the repository name.
Repo string
// Target is the URL that Run POSTs every payload to as application/json.
Target string
// Hook is the hook definition Run uses to create a webhook when listing the
// repository's hooks yields none. Run fills in its unset fields in place.
Hook gogithub.Hook
// PollingDelay is the pause between two polls for new deliveries. Run falls
// back to 10 seconds when it is zero or negative.
PollingDelay time.Duration
// Client provides the API client (Client.Client) that Run hands to the hooks
// and hook-deliveries helpers.
Client *github.Client
// Checkpointer loads the position Run starts polling from and stores the
// position reached after each poll.
Checkpointer Checkpointer
// logger is embedded; Run reports progress and failures through its Logf and
// Errorf methods.
logger
}
// persistentError wraps an error that Run returns for failures such as a
// missing hook URL, a failed hook creation, or a checkpoint that cannot be
// loaded.
//
// NOTE(review): the name suggests callers should not simply retry on it —
// confirm against the code that inspects Run's result.
type persistentError struct {
	// Err is the underlying cause.
	Err error
}

// Error returns the message of the wrapped error ("<nil>" when Err is nil).
func (e persistentError) Error() string {
	return fmt.Sprint(e.Err)
}

// Unwrap exposes the wrapped error so that errors.Is and errors.As can see
// through a persistentError to the underlying cause.
func (e persistentError) Unwrap() error {
	return e.Err
}
// Run forwards the deliveries of a repository webhook to f.Target until ctx is
// done or an unrecoverable error occurs.
//
// It uses the first hook returned when listing the hooks of f.Repo. When there
// is none, it creates one from f.Hook, filling in place any of these that are
// unset: content_type ("json"), insecure_ssl (0), secret (from the
// GITHUB_HOOK_SECRET environment variable), events (check_run and push) and
// active (true). It then repeatedly fetches the deliveries newer than the
// position kept by f.Checkpointer and POSTs each raw payload to f.Target as
// application/json.
//
// The position only advances, in memory and in the checkpointer, after every
// payload of a poll has been sent. When a POST fails, the whole batch is
// fetched and sent again after a short delay, so a target may receive a
// delivery more than once but deliveries are not dropped. A response with a
// non-2xx status is logged as a failure but is not retried.
//
// A missing config.url, a failed hook creation and a checkpoint that cannot be
// loaded are returned as persistentError. A failure to list hooks or to store
// the checkpoint is returned as a plain error. When ctx ends, Run returns
// ctx.Err() or the cancellation error reported by the poll.
func (f *Forwarder) Run(ctx context.Context) error {
	const retryDelay = 5 * time.Second

	pollingDelay := 10 * time.Second
	if f.PollingDelay > 0 {
		pollingDelay = f.PollingDelay
	}

	// wait pauses for d and reports ctx.Err() if ctx finishes first.
	wait := func(d time.Duration) error {
		t := time.NewTimer(d)
		defer t.Stop()

		select {
		case <-t.C:
			return nil
		case <-ctx.Done():
			return ctx.Err()
		}
	}

	// post sends one payload to the target and returns the response status.
	// The request is bound to ctx so that shutdown is not blocked by a slow
	// target.
	post := func(payload []byte) (int, error) {
		req, err := http.NewRequestWithContext(ctx, http.MethodPost, f.Target, bytes.NewReader(payload))
		if err != nil {
			return 0, err
		}
		req.Header.Set("Content-Type", "application/json")

		resp, err := http.DefaultClient.Do(req)
		if err != nil {
			return 0, err
		}

		// Only the status is needed. The body must still be closed or the
		// connection leaks; a close error carries no useful information here.
		_ = resp.Body.Close()

		return resp.StatusCode, nil
	}

	segments := strings.Split(f.Repo, "/")
	owner := segments[0]

	var repo string
	if len(segments) > 1 {
		repo = segments[1]
	}

	hooksClient := newHooksAPI(f.Client.Client, owner, repo)

	hooks, _, err := hooksClient.ListHooks(ctx, nil)
	if err != nil {
		f.Errorf("Failed listing hooks: %v", err)
		return err
	}

	// The first listed hook, if any, is the one whose deliveries are forwarded.
	var hook *gogithub.Hook
	if len(hooks) > 0 {
		hook = hooks[0]
	}

	if hook == nil {
		hookConfig := &f.Hook

		if _, ok := hookConfig.Config["url"]; !ok {
			return persistentError{Err: fmt.Errorf("config.url is missing in the hook config")}
		}

		if _, ok := hookConfig.Config["content_type"]; !ok {
			hookConfig.Config["content_type"] = "json"
		}

		if _, ok := hookConfig.Config["insecure_ssl"]; !ok {
			hookConfig.Config["insecure_ssl"] = 0
		}

		if _, ok := hookConfig.Config["secret"]; !ok {
			hookConfig.Config["secret"] = os.Getenv("GITHUB_HOOK_SECRET")
		}

		if len(hookConfig.Events) == 0 {
			hookConfig.Events = []string{"check_run", "push"}
		}

		if hookConfig.Active == nil {
			hookConfig.Active = gogithub.Bool(true)
		}

		created, _, err := hooksClient.CreateHook(ctx, hookConfig)
		if err != nil {
			f.Errorf("Failed creating hook: %v", err)
			return persistentError{Err: err}
		}

		hook = created
	}

	f.Logf("Using this hook for receiving deliveries to be forwarded: %+v", *hook)

	hookDeliveries := newHookDeliveriesAPI(f.Client.Client, owner, repo, hook.GetID())

	cur, err := f.Checkpointer.GetOrCreate(hook.GetID())
	if err != nil {
		f.Errorf("Failed to get or create log position: %v", err)
		return persistentError{Err: err}
	}

	// Guard against a checkpointer that reports success without a position.
	if cur == nil {
		cur = &State{}
	}

LOOP:
	for {
		payloads, next, err := f.getUnprocessedDeliveries(ctx, hookDeliveries, *cur)
		if err != nil {
			f.Errorf("failed getting unprocessed deliveries: %v", err)

			if errors.Is(err, context.Canceled) {
				return err
			}

			// A failed poll yields no position. Keep the current one and try
			// again later instead of storing or dereferencing a nil position.
			if err := wait(pollingDelay); err != nil {
				return err
			}

			continue
		}

		for _, p := range payloads {
			status, err := post(p)
			if err != nil {
				f.Errorf("failed forwarding delivery: %v", err)

				if err := wait(retryDelay); err != nil {
					return err
				}

				// cur has not advanced yet, so the next poll returns this
				// batch again.
				continue LOOP
			}

			if status < http.StatusOK || status >= http.StatusMultipleChoices {
				f.Errorf("failed forwarding delivery: %s responded with HTTP status %d", f.Target, status)
				continue
			}

			f.Logf("Successfully POSTed the payload to %s", f.Target)
		}

		// Every payload of this poll has been sent; only now move the position.
		cur = next

		if err := f.Checkpointer.Update(hook.GetID(), cur); err != nil {
			return fmt.Errorf("failed updating checkpoint: %w", err)
		}

		if err := wait(pollingDelay); err != nil {
			return err
		}
	}
}
// State is the position up to which hook deliveries have been read. It is
// loaded from and stored through the Checkpointer, and getUnprocessedDeliveries
// stops scanning at deliveries that are not newer than it.
type State struct {
// DeliveredAt is the latest delivery timestamp seen so far. The zero value
// disables the timestamp check.
DeliveredAt time.Time
// ID is the highest delivery ID seen so far. Zero disables the ID check.
ID int64
}
// getUnprocessedDeliveries lists the hook's deliveries page by page and
// collects the ones that are newer than pos.
//
// Every listed delivery is fetched in full and its request payload must parse;
// any failure aborts the scan and is returned. Scanning stops at the first
// delivery whose timestamp is before pos.DeliveredAt (when that is set) or
// whose ID is not greater than pos.ID (when that is non-zero), or when the API
// reports no further cursor. Stopping early assumes the API lists deliveries
// newest first — TODO confirm against the GitHub API documentation.
//
// It returns the raw request payloads of the collected deliveries ordered
// oldest first, and a copy of pos advanced to the latest timestamp and highest
// ID among them. If ctx ends while waiting between pages, ctx.Err() is
// returned.
func (f *Forwarder) getUnprocessedDeliveries(ctx context.Context, hookDeliveries *hookDeliveriesAPI, pos State) ([][]byte, *State, error) {
	opts := gogithub.ListCursorOptions{PerPage: 2}

	// next accumulates the new position. pos itself must stay untouched because
	// it is the cutoff: advancing it while scanning would make the first
	// collected delivery hide every older, still unprocessed one.
	next := pos

	var deliveries []*gogithub.HookDelivery

OUTER:
	for {
		page, resp, err := hookDeliveries.ListHookDeliveries(ctx, &opts)
		if err != nil {
			return nil, nil, err
		}

		opts.Cursor = resp.Cursor

		for _, summary := range page {
			d, _, err := hookDeliveries.GetHookDelivery(ctx, summary.GetID())
			if err != nil {
				return nil, nil, err
			}

			payload, err := d.ParseRequestPayload()
			if err != nil {
				return nil, nil, err
			}

			id := d.GetID()
			deliveredAt := d.GetDeliveredAt()

			if !pos.DeliveredAt.IsZero() && deliveredAt.Before(pos.DeliveredAt) {
				f.Logf("%s is before %s so skipping all the remaining deliveries", deliveredAt, pos.DeliveredAt)
				break OUTER
			}

			if pos.ID != 0 && id <= pos.ID {
				break OUTER
			}

			deliveries = append(deliveries, d)

			f.Logf("Received %T at %s: %v", payload, deliveredAt, payload)

			if deliveredAt.After(next.DeliveredAt) {
				next.DeliveredAt = deliveredAt.Time
			}

			if id > next.ID {
				next.ID = id
			}
		}

		if opts.Cursor == "" {
			break
		}

		// Pause between pages, but give up right away when ctx is done.
		t := time.NewTimer(1 * time.Second)
		select {
		case <-t.C:
		case <-ctx.Done():
			t.Stop()
			return nil, nil, ctx.Err()
		}
	}

	// Oldest first; the ID breaks ties so that deliveries sharing a timestamp
	// keep a deterministic order.
	sort.SliceStable(deliveries, func(i, j int) bool {
		ti := deliveries[i].GetDeliveredAt().Time
		tj := deliveries[j].GetDeliveredAt().Time
		if !ti.Equal(tj) {
			return ti.Before(tj)
		}
		return deliveries[i].GetID() < deliveries[j].GetID()
	})

	payloads := make([][]byte, 0, len(deliveries))
	for _, d := range deliveries {
		payloads = append(payloads, *d.Request.RawPayload)
	}

	return payloads, &next, nil
}