refactor
This commit is contained in:
		
							parent
							
								
									eecea5a1be
								
							
						
					
					
						commit
						c014b45a87
					
				
							
								
								
									
										17
									
								
								main.go
								
								
								
								
							
							
						
						
									
										17
									
								
								main.go
								
								
								
								
							| 
						 | 
					@ -1,9 +1,12 @@
 | 
				
			||||||
package main
 | 
					package main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"os"
 | 
						"os"
 | 
				
			||||||
 | 
						"os/signal"
 | 
				
			||||||
	"runtime"
 | 
						"runtime"
 | 
				
			||||||
 | 
						"syscall"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/ghodss/yaml"
 | 
						"github.com/ghodss/yaml"
 | 
				
			||||||
	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
 | 
						"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
 | 
				
			||||||
| 
						 | 
					@ -55,12 +58,22 @@ func main() {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	validator := NewValidator(opts.EmailDomains, opts.AuthenticatedEmailsFile)
 | 
						validator := NewValidator(opts.EmailDomains, opts.AuthenticatedEmailsFile)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	oauthproxy, err := NewOAuthProxy(opts, validator)
 | 
						ctx, cancelFunc := context.WithCancel(context.Background())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Observe signals in background goroutine.
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							sigint := make(chan os.Signal, 1)
 | 
				
			||||||
 | 
							signal.Notify(sigint, os.Interrupt, syscall.SIGTERM)
 | 
				
			||||||
 | 
							<-sigint
 | 
				
			||||||
 | 
							cancelFunc() // cancel the context
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						oauthproxy, err := NewOAuthProxy(ctx, opts, validator)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		logger.Fatalf("ERROR: Failed to initialise OAuth2 Proxy: %v", err)
 | 
							logger.Fatalf("ERROR: Failed to initialise OAuth2 Proxy: %v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err := oauthproxy.Start(); err != nil {
 | 
						if err := oauthproxy.Start(ctx); err != nil {
 | 
				
			||||||
		logger.Fatalf("ERROR: Failed to start OAuth2 Proxy: %v", err)
 | 
							logger.Fatalf("ERROR: Failed to start OAuth2 Proxy: %v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -10,11 +10,8 @@ import (
 | 
				
			||||||
	"net"
 | 
						"net"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"net/url"
 | 
						"net/url"
 | 
				
			||||||
	"os"
 | 
					 | 
				
			||||||
	"os/signal"
 | 
					 | 
				
			||||||
	"regexp"
 | 
						"regexp"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"syscall"
 | 
					 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/gorilla/mux"
 | 
						"github.com/gorilla/mux"
 | 
				
			||||||
| 
						 | 
					@ -113,14 +110,11 @@ type OAuthProxy struct {
 | 
				
			||||||
	redirectValidator redirect.Validator
 | 
						redirectValidator redirect.Validator
 | 
				
			||||||
	appDirector       redirect.AppDirector
 | 
						appDirector       redirect.AppDirector
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	cancelFunc context.CancelFunc
 | 
					 | 
				
			||||||
	cancelCtx  context.Context
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	encodeState bool
 | 
						encodeState bool
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// NewOAuthProxy creates a new instance of OAuthProxy from the options provided
 | 
					// NewOAuthProxy creates a new instance of OAuthProxy from the options provided
 | 
				
			||||||
func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthProxy, error) {
 | 
					func NewOAuthProxy(ctx context.Context, opts *options.Options, validator func(string) bool) (*OAuthProxy, error) {
 | 
				
			||||||
	sessionStore, err := sessions.NewSessionStore(&opts.Session, &opts.Cookie)
 | 
						sessionStore, err := sessions.NewSessionStore(&opts.Session, &opts.Cookie)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, fmt.Errorf("error initialising session store: %v", err)
 | 
							return nil, fmt.Errorf("error initialising session store: %v", err)
 | 
				
			||||||
| 
						 | 
					@ -203,9 +197,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	cancelCtx, cancelFunc := context.WithCancel(context.Background())
 | 
						preAuthChain, err := buildPreAuthChain(ctx, opts, sessionStore)
 | 
				
			||||||
 | 
					 | 
				
			||||||
	preAuthChain, err := buildPreAuthChain(opts, sessionStore, cancelCtx)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, fmt.Errorf("could not build pre-auth chain: %v", err)
 | 
							return nil, fmt.Errorf("could not build pre-auth chain: %v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					@ -253,8 +245,6 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
 | 
				
			||||||
		redirectValidator:  redirectValidator,
 | 
							redirectValidator:  redirectValidator,
 | 
				
			||||||
		appDirector:        appDirector,
 | 
							appDirector:        appDirector,
 | 
				
			||||||
		encodeState:        opts.EncodeState,
 | 
							encodeState:        opts.EncodeState,
 | 
				
			||||||
		cancelFunc:         cancelFunc,
 | 
					 | 
				
			||||||
		cancelCtx:          cancelCtx,
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	p.buildServeMux(opts.ProxyPrefix)
 | 
						p.buildServeMux(opts.ProxyPrefix)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -265,22 +255,14 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
 | 
				
			||||||
	return p, nil
 | 
						return p, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *OAuthProxy) Start() error {
 | 
					func (p *OAuthProxy) Start(ctx context.Context) error {
 | 
				
			||||||
	if p.server == nil {
 | 
						if p.server == nil {
 | 
				
			||||||
		// We have to call setupServer before Start is called.
 | 
							// We have to call setupServer before Start is called.
 | 
				
			||||||
		// If this doesn't happen it's a programming error.
 | 
							// If this doesn't happen it's a programming error.
 | 
				
			||||||
		panic("server has not been initialised")
 | 
							panic("server has not been initialised")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Observe signals in background goroutine.
 | 
						return p.server.Start(ctx)
 | 
				
			||||||
	go func() {
 | 
					 | 
				
			||||||
		sigint := make(chan os.Signal, 1)
 | 
					 | 
				
			||||||
		signal.Notify(sigint, os.Interrupt, syscall.SIGTERM)
 | 
					 | 
				
			||||||
		<-sigint
 | 
					 | 
				
			||||||
		p.cancelFunc() // cancel the context
 | 
					 | 
				
			||||||
	}()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return p.server.Start(p.cancelCtx)
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *OAuthProxy) setupServer(opts *options.Options) error {
 | 
					func (p *OAuthProxy) setupServer(opts *options.Options) error {
 | 
				
			||||||
| 
						 | 
					@ -359,7 +341,7 @@ func (p *OAuthProxy) buildProxySubrouter(s *mux.Router) {
 | 
				
			||||||
// buildPreAuthChain constructs a chain that should process every request before
 | 
					// buildPreAuthChain constructs a chain that should process every request before
 | 
				
			||||||
// the OAuth2 Proxy authentication logic kicks in.
 | 
					// the OAuth2 Proxy authentication logic kicks in.
 | 
				
			||||||
// For example forcing HTTPS or health checks.
 | 
					// For example forcing HTTPS or health checks.
 | 
				
			||||||
func buildPreAuthChain(opts *options.Options, sessionStore sessionsapi.SessionStore, ctx context.Context) (alice.Chain, error) {
 | 
					func buildPreAuthChain(ctx context.Context, opts *options.Options, sessionStore sessionsapi.SessionStore) (alice.Chain, error) {
 | 
				
			||||||
	chain := alice.New(middleware.NewScope(opts.ReverseProxy, opts.Logging.RequestIDHeader))
 | 
						chain := alice.New(middleware.NewScope(opts.ReverseProxy, opts.Logging.RequestIDHeader))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if opts.ForceHTTPS {
 | 
						if opts.ForceHTTPS {
 | 
				
			||||||
| 
						 | 
					@ -380,7 +362,7 @@ func buildPreAuthChain(opts *options.Options, sessionStore sessionsapi.SessionSt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// To silence logging of health checks, register the health check handler before
 | 
						// To silence logging of health checks, register the health check handler before
 | 
				
			||||||
	// the logging handler
 | 
						// the logging handler
 | 
				
			||||||
	readinessCheck := middleware.NewReadynessCheck(opts.ReadyPath, sessionStore, ctx)
 | 
						readinessCheck := middleware.NewReadynessCheck(ctx, opts.ReadyPath, sessionStore)
 | 
				
			||||||
	if opts.Logging.SilencePing {
 | 
						if opts.Logging.SilencePing {
 | 
				
			||||||
		chain = chain.Append(
 | 
							chain = chain.Append(
 | 
				
			||||||
			middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents),
 | 
								middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents),
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -47,7 +47,7 @@ func TestRobotsTxt(t *testing.T) {
 | 
				
			||||||
	err := validation.Validate(opts)
 | 
						err := validation.Validate(opts)
 | 
				
			||||||
	assert.NoError(t, err)
 | 
						assert.NoError(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
 | 
						proxy, err := NewOAuthProxy(context.Background(), opts, func(string) bool { return true })
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					@ -108,7 +108,7 @@ func Test_redeemCode(t *testing.T) {
 | 
				
			||||||
	err := validation.Validate(opts)
 | 
						err := validation.Validate(opts)
 | 
				
			||||||
	assert.NoError(t, err)
 | 
						assert.NoError(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
 | 
						proxy, err := NewOAuthProxy(context.Background(), opts, func(string) bool { return true })
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					@ -160,7 +160,7 @@ func Test_enrichSession(t *testing.T) {
 | 
				
			||||||
			err := validation.Validate(opts)
 | 
								err := validation.Validate(opts)
 | 
				
			||||||
			assert.NoError(t, err)
 | 
								assert.NoError(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
 | 
								proxy, err := NewOAuthProxy(context.Background(), opts, func(string) bool { return true })
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				t.Fatal(err)
 | 
									t.Fatal(err)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
| 
						 | 
					@ -229,7 +229,7 @@ func TestBasicAuthPassword(t *testing.T) {
 | 
				
			||||||
	providerURL, _ := url.Parse(providerServer.URL)
 | 
						providerURL, _ := url.Parse(providerServer.URL)
 | 
				
			||||||
	const emailAddress = "john.doe@example.com"
 | 
						const emailAddress = "john.doe@example.com"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	proxy, err := NewOAuthProxy(opts, func(email string) bool {
 | 
						proxy, err := NewOAuthProxy(context.Background(), opts, func(email string) bool {
 | 
				
			||||||
		return email == emailAddress
 | 
							return email == emailAddress
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
| 
						 | 
					@ -291,7 +291,7 @@ func TestPassGroupsHeadersWithGroups(t *testing.T) {
 | 
				
			||||||
		CreatedAt:   &created,
 | 
							CreatedAt:   &created,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	proxy, err := NewOAuthProxy(opts, func(email string) bool {
 | 
						proxy, err := NewOAuthProxy(context.Background(), opts, func(email string) bool {
 | 
				
			||||||
		return email == emailAddress
 | 
							return email == emailAddress
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
	assert.NoError(t, err)
 | 
						assert.NoError(t, err)
 | 
				
			||||||
| 
						 | 
					@ -387,7 +387,7 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) (*PassAccessTokenTe
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	testProvider := NewTestProvider(providerURL, emailAddress)
 | 
						testProvider := NewTestProvider(providerURL, emailAddress)
 | 
				
			||||||
	testProvider.ValidToken = opts.ValidToken
 | 
						testProvider.ValidToken = opts.ValidToken
 | 
				
			||||||
	patt.proxy, err = NewOAuthProxy(patt.opts, func(email string) bool {
 | 
						patt.proxy, err = NewOAuthProxy(context.Background(), patt.opts, func(email string) bool {
 | 
				
			||||||
		return email == emailAddress
 | 
							return email == emailAddress
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
	patt.proxy.provider = testProvider
 | 
						patt.proxy.provider = testProvider
 | 
				
			||||||
| 
						 | 
					@ -592,7 +592,7 @@ func NewSignInPageTest(skipProvider bool) (*SignInPageTest, error) {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	sipTest.proxy, err = NewOAuthProxy(sipTest.opts, func(email string) bool {
 | 
						sipTest.proxy, err = NewOAuthProxy(context.Background(), sipTest.opts, func(email string) bool {
 | 
				
			||||||
		return true
 | 
							return true
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
| 
						 | 
					@ -628,7 +628,7 @@ func TestManualSignInStoresUserGroupsInTheSession(t *testing.T) {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	proxy, err := NewOAuthProxy(opts, func(email string) bool {
 | 
						proxy, err := NewOAuthProxy(context.Background(), opts, func(email string) bool {
 | 
				
			||||||
		return true
 | 
							return true
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
| 
						 | 
					@ -676,7 +676,7 @@ func ManualSignInWithCredentials(t *testing.T, user, pass string) int {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	proxy, err := NewOAuthProxy(opts, func(email string) bool {
 | 
						proxy, err := NewOAuthProxy(context.Background(), opts, func(email string) bool {
 | 
				
			||||||
		return true
 | 
							return true
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
| 
						 | 
					@ -826,7 +826,7 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifi
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool {
 | 
						pcTest.proxy, err = NewOAuthProxy(context.Background(), pcTest.opts, func(email string) bool {
 | 
				
			||||||
		return pcTest.validateUser
 | 
							return pcTest.validateUser
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
| 
						 | 
					@ -1200,7 +1200,7 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
 | 
				
			||||||
	err := validation.Validate(pcTest.opts)
 | 
						err := validation.Validate(pcTest.opts)
 | 
				
			||||||
	assert.NoError(t, err)
 | 
						assert.NoError(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool {
 | 
						pcTest.proxy, err = NewOAuthProxy(context.Background(), pcTest.opts, func(email string) bool {
 | 
				
			||||||
		return pcTest.validateUser
 | 
							return pcTest.validateUser
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
| 
						 | 
					@ -1293,7 +1293,7 @@ func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) {
 | 
				
			||||||
	err := validation.Validate(pcTest.opts)
 | 
						err := validation.Validate(pcTest.opts)
 | 
				
			||||||
	assert.NoError(t, err)
 | 
						assert.NoError(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool {
 | 
						pcTest.proxy, err = NewOAuthProxy(context.Background(), pcTest.opts, func(email string) bool {
 | 
				
			||||||
		return pcTest.validateUser
 | 
							return pcTest.validateUser
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
| 
						 | 
					@ -1373,7 +1373,7 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) {
 | 
				
			||||||
	err := validation.Validate(pcTest.opts)
 | 
						err := validation.Validate(pcTest.opts)
 | 
				
			||||||
	assert.NoError(t, err)
 | 
						assert.NoError(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool {
 | 
						pcTest.proxy, err = NewOAuthProxy(context.Background(), pcTest.opts, func(email string) bool {
 | 
				
			||||||
		return pcTest.validateUser
 | 
							return pcTest.validateUser
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
| 
						 | 
					@ -1429,7 +1429,7 @@ func TestAuthSkippedForPreflightRequests(t *testing.T) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	upstreamURL, _ := url.Parse(upstreamServer.URL)
 | 
						upstreamURL, _ := url.Parse(upstreamServer.URL)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	proxy, err := NewOAuthProxy(opts, func(string) bool { return false })
 | 
						proxy, err := NewOAuthProxy(context.Background(), opts, func(string) bool { return false })
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					@ -1554,7 +1554,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) er
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	proxy, err := NewOAuthProxy(st.opts, func(email string) bool { return true })
 | 
						proxy, err := NewOAuthProxy(context.Background(), st.opts, func(email string) bool { return true })
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					@ -1642,7 +1642,7 @@ func newAjaxRequestTest(forceJSONErrors bool) (*ajaxRequestTest, error) {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	test.proxy, err = NewOAuthProxy(test.opts, func(email string) bool {
 | 
						test.proxy, err = NewOAuthProxy(context.Background(), test.opts, func(email string) bool {
 | 
				
			||||||
		return true
 | 
							return true
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
| 
						 | 
					@ -1958,7 +1958,7 @@ func Test_noCacheHeaders(t *testing.T) {
 | 
				
			||||||
	opts.SkipAuthRegex = []string{".*"}
 | 
						opts.SkipAuthRegex = []string{".*"}
 | 
				
			||||||
	err := validation.Validate(opts)
 | 
						err := validation.Validate(opts)
 | 
				
			||||||
	assert.NoError(t, err)
 | 
						assert.NoError(t, err)
 | 
				
			||||||
	proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true })
 | 
						proxy, err := NewOAuthProxy(context.Background(), opts, func(_ string) bool { return true })
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					@ -2233,7 +2233,7 @@ func TestTrustedIPs(t *testing.T) {
 | 
				
			||||||
			err := validation.Validate(opts)
 | 
								err := validation.Validate(opts)
 | 
				
			||||||
			assert.NoError(t, err)
 | 
								assert.NoError(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
 | 
								proxy, err := NewOAuthProxy(context.Background(), opts, func(string) bool { return true })
 | 
				
			||||||
			assert.NoError(t, err)
 | 
								assert.NoError(t, err)
 | 
				
			||||||
			rw := httptest.NewRecorder()
 | 
								rw := httptest.NewRecorder()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -2478,7 +2478,7 @@ func TestApiRoutes(t *testing.T) {
 | 
				
			||||||
	opts.SkipProviderButton = true
 | 
						opts.SkipProviderButton = true
 | 
				
			||||||
	err := validation.Validate(opts)
 | 
						err := validation.Validate(opts)
 | 
				
			||||||
	assert.NoError(t, err)
 | 
						assert.NoError(t, err)
 | 
				
			||||||
	proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true })
 | 
						proxy, err := NewOAuthProxy(context.Background(), opts, func(_ string) bool { return true })
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					@ -2561,7 +2561,7 @@ func TestAllowedRequest(t *testing.T) {
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err := validation.Validate(opts)
 | 
						err := validation.Validate(opts)
 | 
				
			||||||
	assert.NoError(t, err)
 | 
						assert.NoError(t, err)
 | 
				
			||||||
	proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true })
 | 
						proxy, err := NewOAuthProxy(context.Background(), opts, func(_ string) bool { return true })
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					@ -2660,7 +2660,7 @@ func TestAllowedRequestWithForwardedUriHeader(t *testing.T) {
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err := validation.Validate(opts)
 | 
						err := validation.Validate(opts)
 | 
				
			||||||
	assert.NoError(t, err)
 | 
						assert.NoError(t, err)
 | 
				
			||||||
	proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true })
 | 
						proxy, err := NewOAuthProxy(context.Background(), opts, func(_ string) bool { return true })
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					@ -2759,7 +2759,7 @@ func TestAllowedRequestNegateWithoutMethod(t *testing.T) {
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err := validation.Validate(opts)
 | 
						err := validation.Validate(opts)
 | 
				
			||||||
	assert.NoError(t, err)
 | 
						assert.NoError(t, err)
 | 
				
			||||||
	proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true })
 | 
						proxy, err := NewOAuthProxy(context.Background(), opts, func(_ string) bool { return true })
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					@ -2859,7 +2859,7 @@ func TestAllowedRequestNegateWithMethod(t *testing.T) {
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err := validation.Validate(opts)
 | 
						err := validation.Validate(opts)
 | 
				
			||||||
	assert.NoError(t, err)
 | 
						assert.NoError(t, err)
 | 
				
			||||||
	proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true })
 | 
						proxy, err := NewOAuthProxy(context.Background(), opts, func(_ string) bool { return true })
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
| 
						 | 
					@ -3460,7 +3460,7 @@ func TestGetOAuthRedirectURI(t *testing.T) {
 | 
				
			||||||
			err := validation.Validate(baseOpts)
 | 
								err := validation.Validate(baseOpts)
 | 
				
			||||||
			assert.NoError(t, err)
 | 
								assert.NoError(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			proxy, err := NewOAuthProxy(tt.setupOpts(baseOpts), func(string) bool { return true })
 | 
								proxy, err := NewOAuthProxy(context.Background(), tt.setupOpts(baseOpts), func(string) bool { return true })
 | 
				
			||||||
			assert.NoError(t, err)
 | 
								assert.NoError(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			assert.Equalf(t, tt.want, proxy.getOAuthRedirectURI(tt.req), "getOAuthRedirectURI(%v)", tt.req)
 | 
								assert.Equalf(t, tt.want, proxy.getOAuthRedirectURI(tt.req), "getOAuthRedirectURI(%v)", tt.req)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -16,13 +16,13 @@ type Verifiable interface {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// NewReadynessCheck returns a middleware that performs deep health checks
 | 
					// NewReadynessCheck returns a middleware that performs deep health checks
 | 
				
			||||||
// (verifies the connection to any underlying store) on a specific `path`
 | 
					// (verifies the connection to any underlying store) on a specific `path`
 | 
				
			||||||
func NewReadynessCheck(path string, verifiable Verifiable, ctx context.Context) alice.Constructor {
 | 
					func NewReadynessCheck(ctx context.Context, path string, verifiable Verifiable) alice.Constructor {
 | 
				
			||||||
	return func(next http.Handler) http.Handler {
 | 
						return func(next http.Handler) http.Handler {
 | 
				
			||||||
		return readynessCheck(path, verifiable, next, ctx)
 | 
							return readynessCheck(ctx, path, verifiable, next)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func readynessCheck(path string, verifiable Verifiable, next http.Handler, ctx context.Context) http.Handler {
 | 
					func readynessCheck(ctx context.Context, path string, verifiable Verifiable, next http.Handler) http.Handler {
 | 
				
			||||||
	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
 | 
						return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
 | 
				
			||||||
		if ctx.Err() != nil {
 | 
							if ctx.Err() != nil {
 | 
				
			||||||
			rw.WriteHeader(http.StatusServiceUnavailable)
 | 
								rw.WriteHeader(http.StatusServiceUnavailable)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -27,7 +27,7 @@ var _ = Describe("ReadynessCheck suite", func() {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			ctx := context.Background()
 | 
								ctx := context.Background()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			handler := NewReadynessCheck(in.readyPath, in.healthVerifiable, ctx)(http.NotFoundHandler())
 | 
								handler := NewReadynessCheck(ctx, in.readyPath, in.healthVerifiable)(http.NotFoundHandler())
 | 
				
			||||||
			handler.ServeHTTP(rw, req)
 | 
								handler.ServeHTTP(rw, req)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			Expect(rw.Code).To(Equal(in.expectedStatus))
 | 
								Expect(rw.Code).To(Equal(in.expectedStatus))
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue