203 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			203 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Go
		
	
	
	
package githubwebhookdeliveryforwarder
 | 
						|
 | 
						|
import (
 | 
						|
	"bytes"
 | 
						|
	"context"
 | 
						|
	"fmt"
 | 
						|
	"net/http"
 | 
						|
	"os"
 | 
						|
	"sort"
 | 
						|
	"strings"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/actions-runner-controller/actions-runner-controller/github"
 | 
						|
	gogithub "github.com/google/go-github/v39/github"
 | 
						|
)
 | 
						|
 | 
						|
type server struct {
 | 
						|
	target string
 | 
						|
	Repo   string
 | 
						|
	client *github.Client
 | 
						|
}
 | 
						|
 | 
						|
func New(client *github.Client, target string) *server {
 | 
						|
	var srv server
 | 
						|
 | 
						|
	srv.target = target
 | 
						|
	srv.client = client
 | 
						|
 | 
						|
	return &srv
 | 
						|
}
 | 
						|
 | 
						|
func (s *server) Run(ctx context.Context) error {
 | 
						|
	segments := strings.Split(s.Repo, "/")
 | 
						|
 | 
						|
	if len(segments) != 2 {
 | 
						|
		return fmt.Errorf("repository must be in a form of OWNER/REPO: got %q", s.Repo)
 | 
						|
	}
 | 
						|
 | 
						|
	owner, repo := segments[0], segments[1]
 | 
						|
 | 
						|
	hooks, _, err := s.client.Repositories.ListHooks(ctx, owner, repo, nil)
 | 
						|
	if err != nil {
 | 
						|
		s.Errorf("Failed listing hooks: %v", err)
 | 
						|
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	var hook *gogithub.Hook
 | 
						|
 | 
						|
	for i := range hooks {
 | 
						|
		hook = hooks[i]
 | 
						|
		break
 | 
						|
	}
 | 
						|
 | 
						|
	cur := &cursor{}
 | 
						|
 | 
						|
	cur.deliveredAt = time.Now()
 | 
						|
 | 
						|
	for {
 | 
						|
		var (
 | 
						|
			err      error
 | 
						|
			payloads [][]byte
 | 
						|
		)
 | 
						|
 | 
						|
		payloads, cur, err = s.getUnprocessedDeliveries(ctx, owner, repo, hook.GetID(), *cur)
 | 
						|
		if err != nil {
 | 
						|
			s.Errorf("failed getting unprocessed deliveries: %v", err)
 | 
						|
		}
 | 
						|
 | 
						|
		for _, p := range payloads {
 | 
						|
			if _, err := http.Post(s.target, "application/json", bytes.NewReader(p)); err != nil {
 | 
						|
				s.Errorf("failed forwarding delivery: %v", err)
 | 
						|
			}
 | 
						|
		}
 | 
						|
 | 
						|
		time.Sleep(10 * time.Second)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// cursor marks how far the delivery stream has been processed. Deliveries
// older than deliveredAt, or with an ID not greater than id, are treated as
// already handled by getUnprocessedDeliveries.
type cursor struct {
	// deliveredAt is the latest delivery time seen so far; the zero value
	// disables the time-based check.
	deliveredAt time.Time
	// id is the highest delivery ID seen so far; zero disables the ID-based
	// check.
	id int64
}
 | 
						|
 | 
						|
func (s *server) getUnprocessedDeliveries(ctx context.Context, owner, repo string, hookID int64, pos cursor) ([][]byte, *cursor, error) {
 | 
						|
	var (
 | 
						|
		opts gogithub.ListCursorOptions
 | 
						|
	)
 | 
						|
 | 
						|
	opts.PerPage = 2
 | 
						|
 | 
						|
	var deliveries []*gogithub.HookDelivery
 | 
						|
 | 
						|
OUTER:
 | 
						|
	for {
 | 
						|
		ds, resp, err := s.client.Repositories.ListHookDeliveries(ctx, owner, repo, hookID, &opts)
 | 
						|
		if err != nil {
 | 
						|
			return nil, nil, err
 | 
						|
		}
 | 
						|
 | 
						|
		opts.Cursor = resp.Cursor
 | 
						|
 | 
						|
		for _, d := range ds {
 | 
						|
			d, _, err := s.client.Repositories.GetHookDelivery(ctx, owner, repo, hookID, d.GetID())
 | 
						|
			if err != nil {
 | 
						|
				return nil, nil, err
 | 
						|
			}
 | 
						|
 | 
						|
			payload, err := d.ParseRequestPayload()
 | 
						|
			if err != nil {
 | 
						|
				return nil, nil, err
 | 
						|
			}
 | 
						|
 | 
						|
			id := d.GetID()
 | 
						|
			deliveredAt := d.GetDeliveredAt()
 | 
						|
 | 
						|
			if !pos.deliveredAt.IsZero() && deliveredAt.Before(pos.deliveredAt) {
 | 
						|
				s.Logf("%s is before %s so skipping all the remaining deliveries", deliveredAt, pos.deliveredAt)
 | 
						|
				break OUTER
 | 
						|
			}
 | 
						|
 | 
						|
			if pos.id != 0 && id <= pos.id {
 | 
						|
				break OUTER
 | 
						|
			}
 | 
						|
 | 
						|
			s.Logf("Received %T at %s: %v", payload, deliveredAt, payload)
 | 
						|
 | 
						|
			if deliveredAt.After(pos.deliveredAt) {
 | 
						|
				pos.deliveredAt = deliveredAt.Time
 | 
						|
			}
 | 
						|
 | 
						|
			if id > pos.id {
 | 
						|
				pos.id = id
 | 
						|
			}
 | 
						|
		}
 | 
						|
 | 
						|
		if opts.Cursor == "" {
 | 
						|
			break
 | 
						|
		}
 | 
						|
 | 
						|
		time.Sleep(1 * time.Second)
 | 
						|
	}
 | 
						|
 | 
						|
	sort.Slice(deliveries, func(a, b int) bool {
 | 
						|
		return deliveries[b].GetDeliveredAt().After(deliveries[a].GetDeliveredAt().Time)
 | 
						|
	})
 | 
						|
 | 
						|
	var payloads [][]byte
 | 
						|
 | 
						|
	for _, d := range deliveries {
 | 
						|
		payloads = append(payloads, *d.Request.RawPayload)
 | 
						|
	}
 | 
						|
 | 
						|
	return payloads, &pos, nil
 | 
						|
}
 | 
						|
 | 
						|
func (s *server) HandleReadyz(w http.ResponseWriter, r *http.Request) {
 | 
						|
	var (
 | 
						|
		ok bool
 | 
						|
 | 
						|
		err error
 | 
						|
	)
 | 
						|
 | 
						|
	defer func() {
 | 
						|
		if !ok {
 | 
						|
			w.WriteHeader(http.StatusInternalServerError)
 | 
						|
 | 
						|
			if err != nil {
 | 
						|
				msg := err.Error()
 | 
						|
				if _, err := w.Write([]byte(msg)); err != nil {
 | 
						|
					s.Errorf("failed writing http error response: %v", err)
 | 
						|
				}
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}()
 | 
						|
 | 
						|
	defer func() {
 | 
						|
		if r.Body != nil {
 | 
						|
			r.Body.Close()
 | 
						|
		}
 | 
						|
	}()
 | 
						|
 | 
						|
	// respond ok to GET / e.g. for health check
 | 
						|
	if r.Method == http.MethodGet {
 | 
						|
		fmt.Fprintln(w, "webhook server is running")
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	w.WriteHeader(http.StatusOK)
 | 
						|
 | 
						|
	if _, err := w.Write([]byte("ok")); err != nil {
 | 
						|
		s.Errorf("failed writing http response: %v", err)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (s *server) Logf(format string, args ...interface{}) {
 | 
						|
	fmt.Fprintf(os.Stdout, format+"\n", args...)
 | 
						|
}
 | 
						|
 | 
						|
func (s *server) Errorf(format string, args ...interface{}) {
 | 
						|
	fmt.Fprintf(os.Stderr, format+"\n", args...)
 | 
						|
}
 |