From c014b45a87c8244754a5ab1675859f0a87ae7b57 Mon Sep 17 00:00:00 2001 From: Adel Salakh Date: Thu, 15 May 2025 01:03:08 +0200 Subject: [PATCH] refactor --- main.go | 17 ++++++++-- oauthproxy.go | 30 ++++------------- oauthproxy_test.go | 48 +++++++++++++-------------- pkg/middleware/readynesscheck.go | 6 ++-- pkg/middleware/readynesscheck_test.go | 2 +- 5 files changed, 49 insertions(+), 54 deletions(-) diff --git a/main.go b/main.go index 4bef9ad4..c3877bce 100644 --- a/main.go +++ b/main.go @@ -1,9 +1,12 @@ package main import ( + "context" "fmt" "os" + "os/signal" "runtime" + "syscall" "github.com/ghodss/yaml" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" @@ -55,12 +58,22 @@ func main() { validator := NewValidator(opts.EmailDomains, opts.AuthenticatedEmailsFile) - oauthproxy, err := NewOAuthProxy(opts, validator) + ctx, cancelFunc := context.WithCancel(context.Background()) + + // Observe signals in background goroutine. + go func() { + sigint := make(chan os.Signal, 1) + signal.Notify(sigint, os.Interrupt, syscall.SIGTERM) + <-sigint + cancelFunc() // cancel the context + }() + + oauthproxy, err := NewOAuthProxy(ctx, opts, validator) if err != nil { logger.Fatalf("ERROR: Failed to initialise OAuth2 Proxy: %v", err) } - if err := oauthproxy.Start(); err != nil { + if err := oauthproxy.Start(ctx); err != nil { logger.Fatalf("ERROR: Failed to start OAuth2 Proxy: %v", err) } } diff --git a/oauthproxy.go b/oauthproxy.go index d8ec273e..8439e1ce 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -10,11 +10,8 @@ import ( "net" "net/http" "net/url" - "os" - "os/signal" "regexp" "strings" - "syscall" "time" "github.com/gorilla/mux" @@ -113,14 +110,11 @@ type OAuthProxy struct { redirectValidator redirect.Validator appDirector redirect.AppDirector - cancelFunc context.CancelFunc - cancelCtx context.Context - encodeState bool } // NewOAuthProxy creates a new instance of OAuthProxy from the options provided -func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthProxy, error) { +func NewOAuthProxy(ctx context.Context, opts *options.Options, validator func(string) bool) (*OAuthProxy, error) { sessionStore, err := sessions.NewSessionStore(&opts.Session, &opts.Cookie) if err != nil { return nil, fmt.Errorf("error initialising session store: %v", err) @@ -203,9 +197,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr return nil, err } - cancelCtx, cancelFunc := context.WithCancel(context.Background()) - - preAuthChain, err := buildPreAuthChain(opts, sessionStore, cancelCtx) + preAuthChain, err := buildPreAuthChain(ctx, opts, sessionStore) if err != nil { return nil, fmt.Errorf("could not build pre-auth chain: %v", err) } @@ -253,8 +245,6 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr redirectValidator: redirectValidator, appDirector: appDirector, encodeState: opts.EncodeState, - cancelFunc: cancelFunc, - cancelCtx: cancelCtx, } p.buildServeMux(opts.ProxyPrefix) @@ -265,22 +255,14 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr return p, nil } -func (p *OAuthProxy) Start() error { +func (p *OAuthProxy) Start(ctx context.Context) error { if p.server == nil { // We have to call setupServer before Start is called. // If this doesn't happen it's a programming error. panic("server has not been initialised") } - // Observe signals in background goroutine. - go func() { - sigint := make(chan os.Signal, 1) - signal.Notify(sigint, os.Interrupt, syscall.SIGTERM) - <-sigint - p.cancelFunc() // cancel the context - }() - - return p.server.Start(p.cancelCtx) + return p.server.Start(ctx) } func (p *OAuthProxy) setupServer(opts *options.Options) error { @@ -359,7 +341,7 @@ func (p *OAuthProxy) buildProxySubrouter(s *mux.Router) { // buildPreAuthChain constructs a chain that should process every request before // the OAuth2 Proxy authentication logic kicks in. // For example forcing HTTPS or health checks. -func buildPreAuthChain(opts *options.Options, sessionStore sessionsapi.SessionStore, ctx context.Context) (alice.Chain, error) { +func buildPreAuthChain(ctx context.Context, opts *options.Options, sessionStore sessionsapi.SessionStore) (alice.Chain, error) { chain := alice.New(middleware.NewScope(opts.ReverseProxy, opts.Logging.RequestIDHeader)) if opts.ForceHTTPS { @@ -380,7 +362,7 @@ func buildPreAuthChain(opts *options.Options, sessionStore sessionsapi.SessionSt // To silence logging of health checks, register the health check handler before // the logging handler - readinessCheck := middleware.NewReadynessCheck(opts.ReadyPath, sessionStore, ctx) + readinessCheck := middleware.NewReadynessCheck(ctx, opts.ReadyPath, sessionStore) if opts.Logging.SilencePing { chain = chain.Append( middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents), diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 488b8cea..0f0edbbd 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -47,7 +47,7 @@ func TestRobotsTxt(t *testing.T) { err := validation.Validate(opts) assert.NoError(t, err) - proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) + proxy, err := NewOAuthProxy(context.Background(), opts, func(string) bool { return true }) if err != nil { t.Fatal(err) } @@ -108,7 +108,7 @@ func Test_redeemCode(t *testing.T) { err := validation.Validate(opts) assert.NoError(t, err) - proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) + proxy, err := NewOAuthProxy(context.Background(), opts, func(string) bool { return true }) if err != nil { t.Fatal(err) } @@ -160,7 +160,7 @@ func Test_enrichSession(t *testing.T) { err := validation.Validate(opts) assert.NoError(t, err) - proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) + proxy, err := NewOAuthProxy(context.Background(), opts, func(string) bool { return true }) if err != nil { t.Fatal(err) } @@ -229,7 +229,7 @@ func TestBasicAuthPassword(t *testing.T) { providerURL, _ := url.Parse(providerServer.URL) const emailAddress = "john.doe@example.com" - proxy, err := NewOAuthProxy(opts, func(email string) bool { + proxy, err := NewOAuthProxy(context.Background(), opts, func(email string) bool { return email == emailAddress }) if err != nil { @@ -291,7 +291,7 @@ func TestPassGroupsHeadersWithGroups(t *testing.T) { CreatedAt: &created, } - proxy, err := NewOAuthProxy(opts, func(email string) bool { + proxy, err := NewOAuthProxy(context.Background(), opts, func(email string) bool { return email == emailAddress }) assert.NoError(t, err) @@ -387,7 +387,7 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) (*PassAccessTokenTe testProvider := NewTestProvider(providerURL, emailAddress) testProvider.ValidToken = opts.ValidToken - patt.proxy, err = NewOAuthProxy(patt.opts, func(email string) bool { + patt.proxy, err = NewOAuthProxy(context.Background(), patt.opts, func(email string) bool { return email == emailAddress }) patt.proxy.provider = testProvider @@ -592,7 +592,7 @@ func NewSignInPageTest(skipProvider bool) (*SignInPageTest, error) { return nil, err } - sipTest.proxy, err = NewOAuthProxy(sipTest.opts, func(email string) bool { + sipTest.proxy, err = NewOAuthProxy(context.Background(), sipTest.opts, func(email string) bool { return true }) if err != nil { @@ -628,7 +628,7 @@ func TestManualSignInStoresUserGroupsInTheSession(t *testing.T) { t.Fatal(err) } - proxy, err := NewOAuthProxy(opts, func(email string) bool { + proxy, err := NewOAuthProxy(context.Background(), opts, func(email string) bool { return true }) if err != nil { @@ -676,7 +676,7 @@ func ManualSignInWithCredentials(t *testing.T, user, pass string) int { t.Fatal(err) } - proxy, err := NewOAuthProxy(opts, func(email string) bool { + proxy, err := NewOAuthProxy(context.Background(), opts, func(email string) bool { return true }) if err != nil { @@ -826,7 +826,7 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifi return nil, err } - pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool { + pcTest.proxy, err = NewOAuthProxy(context.Background(), pcTest.opts, func(email string) bool { return pcTest.validateUser }) if err != nil { @@ -1200,7 +1200,7 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { err := validation.Validate(pcTest.opts) assert.NoError(t, err) - pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool { + pcTest.proxy, err = NewOAuthProxy(context.Background(), pcTest.opts, func(email string) bool { return pcTest.validateUser }) if err != nil { @@ -1293,7 +1293,7 @@ func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) { err := validation.Validate(pcTest.opts) assert.NoError(t, err) - pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool { + pcTest.proxy, err = NewOAuthProxy(context.Background(), pcTest.opts, func(email string) bool { return pcTest.validateUser }) if err != nil { @@ -1373,7 +1373,7 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) { err := validation.Validate(pcTest.opts) assert.NoError(t, err) - pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool { + pcTest.proxy, err = NewOAuthProxy(context.Background(), pcTest.opts, func(email string) bool { return pcTest.validateUser }) if err != nil { @@ -1429,7 +1429,7 @@ func TestAuthSkippedForPreflightRequests(t *testing.T) { upstreamURL, _ := url.Parse(upstreamServer.URL) - proxy, err := NewOAuthProxy(opts, func(string) bool { return false }) + proxy, err := NewOAuthProxy(context.Background(), opts, func(string) bool { return false }) if err != nil { t.Fatal(err) } @@ -1554,7 +1554,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) er if err != nil { return err } - proxy, err := NewOAuthProxy(st.opts, func(email string) bool { return true }) + proxy, err := NewOAuthProxy(context.Background(), st.opts, func(email string) bool { return true }) if err != nil { return err } @@ -1642,7 +1642,7 @@ func newAjaxRequestTest(forceJSONErrors bool) (*ajaxRequestTest, error) { return nil, err } - test.proxy, err = NewOAuthProxy(test.opts, func(email string) bool { + test.proxy, err = NewOAuthProxy(context.Background(), test.opts, func(email string) bool { return true }) if err != nil { @@ -1958,7 +1958,7 @@ func Test_noCacheHeaders(t *testing.T) { opts.SkipAuthRegex = []string{".*"} err := validation.Validate(opts) assert.NoError(t, err) - proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) + proxy, err := NewOAuthProxy(context.Background(), opts, func(_ string) bool { return true }) if err != nil { t.Fatal(err) } @@ -2233,7 +2233,7 @@ func TestTrustedIPs(t *testing.T) { err := validation.Validate(opts) assert.NoError(t, err) - proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) + proxy, err := NewOAuthProxy(context.Background(), opts, func(string) bool { return true }) assert.NoError(t, err) rw := httptest.NewRecorder() @@ -2478,7 +2478,7 @@ func TestApiRoutes(t *testing.T) { opts.SkipProviderButton = true err := validation.Validate(opts) assert.NoError(t, err) - proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) + proxy, err := NewOAuthProxy(context.Background(), opts, func(_ string) bool { return true }) if err != nil { t.Fatal(err) } @@ -2561,7 +2561,7 @@ func TestAllowedRequest(t *testing.T) { } err := validation.Validate(opts) assert.NoError(t, err) - proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) + proxy, err := NewOAuthProxy(context.Background(), opts, func(_ string) bool { return true }) if err != nil { t.Fatal(err) } @@ -2660,7 +2660,7 @@ func TestAllowedRequestWithForwardedUriHeader(t *testing.T) { } err := validation.Validate(opts) assert.NoError(t, err) - proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) + proxy, err := NewOAuthProxy(context.Background(), opts, func(_ string) bool { return true }) if err != nil { t.Fatal(err) } @@ -2759,7 +2759,7 @@ func TestAllowedRequestNegateWithoutMethod(t *testing.T) { } err := validation.Validate(opts) assert.NoError(t, err) - proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) + proxy, err := NewOAuthProxy(context.Background(), opts, func(_ string) bool { return true }) if err != nil { t.Fatal(err) } @@ -2859,7 +2859,7 @@ func TestAllowedRequestNegateWithMethod(t *testing.T) { } err := validation.Validate(opts) assert.NoError(t, err) - proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) + proxy, err := NewOAuthProxy(context.Background(), opts, func(_ string) bool { return true }) if err != nil { t.Fatal(err) } @@ -3460,7 +3460,7 @@ func TestGetOAuthRedirectURI(t *testing.T) { err := validation.Validate(baseOpts) assert.NoError(t, err) - proxy, err := NewOAuthProxy(tt.setupOpts(baseOpts), func(string) bool { return true }) + proxy, err := NewOAuthProxy(context.Background(), tt.setupOpts(baseOpts), func(string) bool { return true }) assert.NoError(t, err) assert.Equalf(t, tt.want, proxy.getOAuthRedirectURI(tt.req), "getOAuthRedirectURI(%v)", tt.req) diff --git a/pkg/middleware/readynesscheck.go b/pkg/middleware/readynesscheck.go index bf3e6373..6977cf2e 100644 --- a/pkg/middleware/readynesscheck.go +++ b/pkg/middleware/readynesscheck.go @@ -16,13 +16,13 @@ type Verifiable interface { // NewReadynessCheck returns a middleware that performs deep health checks // (verifies the connection to any underlying store) on a specific `path` -func NewReadynessCheck(path string, verifiable Verifiable, ctx context.Context) alice.Constructor { +func NewReadynessCheck(ctx context.Context, path string, verifiable Verifiable) alice.Constructor { return func(next http.Handler) http.Handler { - return readynessCheck(path, verifiable, next, ctx) + return readynessCheck(ctx, path, verifiable, next) } } -func readynessCheck(path string, verifiable Verifiable, next http.Handler, ctx context.Context) http.Handler { +func readynessCheck(ctx context.Context, path string, verifiable Verifiable, next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { if ctx.Err() != nil { rw.WriteHeader(http.StatusServiceUnavailable) diff --git a/pkg/middleware/readynesscheck_test.go b/pkg/middleware/readynesscheck_test.go index a00b8d88..952fc3f0 100644 --- a/pkg/middleware/readynesscheck_test.go +++ b/pkg/middleware/readynesscheck_test.go @@ -27,7 +27,7 @@ var _ = Describe("ReadynessCheck suite", func() { ctx := context.Background() - handler := NewReadynessCheck(in.readyPath, in.healthVerifiable, ctx)(http.NotFoundHandler()) + handler := NewReadynessCheck(ctx, in.readyPath, in.healthVerifiable)(http.NotFoundHandler()) handler.ServeHTTP(rw, req) Expect(rw.Code).To(Equal(in.expectedStatus))