254 lines
		
	
	
		
			4.8 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			254 lines
		
	
	
		
			4.8 KiB
		
	
	
	
		
			Go
		
	
	
	
package hookdeliveryforwarder
 | 
						|
 | 
						|
import (
 | 
						|
	"bytes"
 | 
						|
	"context"
 | 
						|
	"errors"
 | 
						|
	"fmt"
 | 
						|
	"net/http"
 | 
						|
	"os"
 | 
						|
	"sort"
 | 
						|
	"strings"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/actions/actions-runner-controller/github"
 | 
						|
	gogithub "github.com/google/go-github/v52/github"
 | 
						|
)
 | 
						|
 | 
						|
type Forwarder struct {
 | 
						|
	Repo   string
 | 
						|
	Target string
 | 
						|
 | 
						|
	Hook gogithub.Hook
 | 
						|
 | 
						|
	PollingDelay time.Duration
 | 
						|
 | 
						|
	Client *github.Client
 | 
						|
 | 
						|
	Checkpointer Checkpointer
 | 
						|
 | 
						|
	logger
 | 
						|
}
 | 
						|
 | 
						|
type persistentError struct {
 | 
						|
	Err error
 | 
						|
}
 | 
						|
 | 
						|
func (e persistentError) Error() string {
 | 
						|
	return fmt.Sprintf("%v", e.Err)
 | 
						|
}
 | 
						|
 | 
						|
// Run makes sure a webhook exists for f.Repo and then forwards its deliveries
// to f.Target until ctx is done or an unrecoverable error occurs.
//
// f.Repo is split on "/": the first segment is the owner and the second, if
// present, the repository. The first hook returned by the hooks API is used.
// When there is none, one is created from f.Hook, which must set
// Config["url"]. Creation defaults content_type to "json", insecure_ssl to 0,
// secret to $GITHUB_HOOK_SECRET, Events to check_run and push, and Active to
// true. These defaults are written into f.Hook itself.
//
// Deliveries are polled every f.PollingDelay (10s when not positive). Each
// batch is POSTed in the order getUnprocessedDeliveries returns it. The
// position is advanced and persisted only after the whole batch was forwarded.
// A failed POST therefore makes the next poll, 5s later, start again from the
// last position and resend the batch (at-least-once delivery).
//
// Failures setting up the hook or loading the checkpoint are returned as
// persistentError. Listing errors wrapping context.Canceled are returned
// as-is. Cancellation while waiting returns ctx.Err().
func (f *Forwarder) Run(ctx context.Context) error {
	pollingDelay := 10 * time.Second
	if f.PollingDelay > 0 {
		pollingDelay = f.PollingDelay
	}

	segments := strings.Split(f.Repo, "/")
	owner := segments[0]
	var repo string
	if len(segments) > 1 {
		repo = segments[1]
	}

	hooksAPI := newHooksAPI(f.Client.Client, owner, repo)

	hooks, _, err := hooksAPI.ListHooks(ctx, nil)
	if err != nil {
		f.Errorf("Failed listing hooks: %v", err)
		return err
	}

	var hook *gogithub.Hook
	if len(hooks) > 0 {
		hook = hooks[0]
	}

	if hook == nil {
		hookConfig := &f.Hook

		if _, ok := hookConfig.Config["url"]; !ok {
			return persistentError{Err: fmt.Errorf("config.url is missing in the hook config")}
		}

		if _, ok := hookConfig.Config["content_type"]; !ok {
			hookConfig.Config["content_type"] = "json"
		}

		if _, ok := hookConfig.Config["insecure_ssl"]; !ok {
			hookConfig.Config["insecure_ssl"] = 0
		}

		if _, ok := hookConfig.Config["secret"]; !ok {
			hookConfig.Config["secret"] = os.Getenv("GITHUB_HOOK_SECRET")
		}

		if len(hookConfig.Events) == 0 {
			hookConfig.Events = []string{"check_run", "push"}
		}

		if hookConfig.Active == nil {
			hookConfig.Active = gogithub.Bool(true)
		}

		h, _, err := hooksAPI.CreateHook(ctx, hookConfig)
		if err != nil {
			f.Errorf("Failed creating hook: %v", err)
			return persistentError{Err: err}
		}

		hook = h
	}

	f.Logf("Using this hook for receiving deliveries to be forwarded: %+v", *hook)

	hookDeliveries := newHookDeliveriesAPI(f.Client.Client, owner, repo, hook.GetID())

	cur, err := f.Checkpointer.GetOrCreate(hook.GetID())
	if err != nil {
		f.Errorf("Failed to get or create log position: %v", err)
		return persistentError{Err: err}
	}

	for {
		payloads, next, err := f.getUnprocessedDeliveries(ctx, hookDeliveries, *cur)
		if err != nil {
			f.Errorf("failed getting unprocessed deliveries: %v", err)
			if errors.Is(err, context.Canceled) {
				return err
			}
			// The returned position is nil on error. Keep the current one and
			// retry on the next poll instead of persisting/dereferencing nil.
			if err := f.waitFor(ctx, pollingDelay); err != nil {
				return err
			}
			continue
		}

		var forwardErr error
		for _, p := range payloads {
			if forwardErr = f.forwardPayload(ctx, p); forwardErr != nil {
				break
			}
			f.Logf("Successfully POSTed the payload to %s", f.Target)
		}
		if forwardErr != nil {
			f.Errorf("failed forwarding delivery: %v", forwardErr)
			// cur is deliberately not advanced: the next poll re-fetches from
			// the last position, so the failed delivery is retried rather than
			// skipped.
			if err := f.waitFor(ctx, 5*time.Second); err != nil {
				return err
			}
			continue
		}

		cur = next
		if err := f.Checkpointer.Update(hook.GetID(), cur); err != nil {
			return fmt.Errorf("failed updating checkpoint: %w", err)
		}

		if err := f.waitFor(ctx, pollingDelay); err != nil {
			return err
		}
	}
}

// forwardPayload POSTs one delivery payload to f.Target as application/json.
// The request is bound to ctx, the response body is always closed, and a
// non-2xx status is reported as an error.
func (f *Forwarder) forwardPayload(ctx context.Context, payload []byte) error {
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, f.Target, bytes.NewReader(payload))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return fmt.Errorf("unexpected response status %q from %s", resp.Status, f.Target)
	}
	return nil
}

// waitFor blocks for d, or until ctx is done, in which case it returns
// ctx.Err().
func (f *Forwarder) waitFor(ctx context.Context, d time.Duration) error {
	t := time.NewTimer(d)
	defer t.Stop()

	select {
	case <-t.C:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
 | 
						|
 | 
						|
type State struct {
 | 
						|
	DeliveredAt time.Time
 | 
						|
	ID          int64
 | 
						|
}
 | 
						|
 | 
						|
// getUnprocessedDeliveries pages through the hook's deliveries and returns the
// raw request payloads of those newer than pos, oldest first, together with
// the advanced position.
//
// Scanning stops at the first delivery that was delivered before
// pos.DeliveredAt, or whose ID is not greater than pos.ID. A zero
// DeliveredAt or ID disables the respective check.
//
// The returned position is pos advanced to the greatest delivery time and ID
// among the deliveries scanned, or an unchanged copy of pos when nothing is
// new. A delivery whose payload cannot be decoded is still forwarded raw;
// decoding is only used for logging. On any API error, or cancellation
// between pages, it returns nil payloads, a nil position and the error.
func (f *Forwarder) getUnprocessedDeliveries(ctx context.Context, hookDeliveries *hookDeliveriesAPI, pos State) ([][]byte, *State, error) {
	var opts gogithub.ListCursorOptions
	opts.PerPage = 2

	// next accumulates the newest position seen. It must stay separate from
	// pos: comparing later (older) deliveries against an already advanced
	// position would end the scan after the first new delivery and silently
	// drop every other new one.
	next := pos

	var deliveries []*gogithub.HookDelivery

OUTER:
	for {
		ds, resp, err := hookDeliveries.ListHookDeliveries(ctx, &opts)
		if err != nil {
			return nil, nil, err
		}

		opts.Cursor = resp.Cursor

		for _, summary := range ds {
			d, _, err := hookDeliveries.GetHookDelivery(ctx, summary.GetID())
			if err != nil {
				return nil, nil, err
			}

			id := d.GetID()
			deliveredAt := d.GetDeliveredAt()

			if !pos.DeliveredAt.IsZero() && deliveredAt.Before(pos.DeliveredAt) {
				f.Logf("%s is before %s so skipping all the remaining deliveries", deliveredAt, pos.DeliveredAt)
				break OUTER
			}

			if pos.ID != 0 && id <= pos.ID {
				break OUTER
			}

			if deliveredAt.After(next.DeliveredAt) {
				next.DeliveredAt = deliveredAt.Time
			}
			if id > next.ID {
				next.ID = id
			}

			if d.Request == nil || d.Request.RawPayload == nil {
				f.Errorf("Delivery %d at %s has no request payload; not forwarding it", id, deliveredAt)
				continue
			}

			// The decoded payload is only used for logging, so a decode
			// failure (e.g. an event type go-github does not know) must not
			// block forwarding of the raw payload.
			payload, err := d.ParseRequestPayload()
			if err != nil {
				f.Logf("Received delivery %d at %s that could not be decoded (%v); forwarding its raw payload anyway", id, deliveredAt, err)
			} else {
				f.Logf("Received %T at %s: %v", payload, deliveredAt, payload)
			}

			deliveries = append(deliveries, d)
		}

		if opts.Cursor == "" {
			break
		}

		// Pause between pages, but stop promptly when ctx is cancelled.
		t := time.NewTimer(time.Second)
		select {
		case <-t.C:
		case <-ctx.Done():
			t.Stop()
			return nil, nil, ctx.Err()
		}
	}

	// Forward in chronological order; ties on delivery time fall back to ID so
	// the order is deterministic.
	sort.Slice(deliveries, func(a, b int) bool {
		ta := deliveries[a].GetDeliveredAt().Time
		tb := deliveries[b].GetDeliveredAt().Time
		if !ta.Equal(tb) {
			return ta.Before(tb)
		}
		return deliveries[a].GetID() < deliveries[b].GetID()
	})

	payloads := make([][]byte, 0, len(deliveries))
	for _, d := range deliveries {
		payloads = append(payloads, *d.Request.RawPayload)
	}

	return payloads, &next, nil
}
 |