Integrate new provider constructor in main
This commit is contained in:
		
							parent
							
								
									2e15f57b70
								
							
						
					
					
						commit
						0791aef8cc
					
				|  | @ -114,6 +114,11 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | |||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	provider, err := providers.NewProvider(opts.Providers[0]) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error intiailising provider: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	pageWriter, err := pagewriter.NewWriter(pagewriter.Opts{ | ||||
| 		TemplatesPath:    opts.Templates.Path, | ||||
| 		CustomLogo:       opts.Templates.CustomLogo, | ||||
|  | @ -121,7 +126,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | |||
| 		Footer:           opts.Templates.Footer, | ||||
| 		Version:          VERSION, | ||||
| 		Debug:            opts.Templates.Debug, | ||||
| 		ProviderName:     buildProviderName(opts.GetProvider(), opts.Providers[0].Name), | ||||
| 		ProviderName:     buildProviderName(provider, opts.Providers[0].Name), | ||||
| 		SignInMessage:    buildSignInMessage(opts), | ||||
| 		DisplayLoginForm: basicAuthValidator != nil && opts.Templates.DisplayLoginForm, | ||||
| 	}) | ||||
|  | @ -145,7 +150,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | |||
| 		redirectURL.Path = fmt.Sprintf("%s/callback", opts.ProxyPrefix) | ||||
| 	} | ||||
| 
 | ||||
| 	logger.Printf("OAuthProxy configured for %s Client ID: %s", opts.GetProvider().Data().ProviderName, opts.Providers[0].ClientID) | ||||
| 	logger.Printf("OAuthProxy configured for %s Client ID: %s", provider.Data().ProviderName, opts.Providers[0].ClientID) | ||||
| 	refresh := "disabled" | ||||
| 	if opts.Cookie.Refresh != time.Duration(0) { | ||||
| 		refresh = fmt.Sprintf("after %s", opts.Cookie.Refresh) | ||||
|  | @ -171,7 +176,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | |||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("could not build pre-auth chain: %v", err) | ||||
| 	} | ||||
| 	sessionChain := buildSessionChain(opts, sessionStore, basicAuthValidator) | ||||
| 	sessionChain := buildSessionChain(opts, provider, sessionStore, basicAuthValidator) | ||||
| 	headersChain, err := buildHeadersChain(opts) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("could not build headers chain: %v", err) | ||||
|  | @ -190,7 +195,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | |||
| 		SignInPath: fmt.Sprintf("%s/sign_in", opts.ProxyPrefix), | ||||
| 
 | ||||
| 		ProxyPrefix:         opts.ProxyPrefix, | ||||
| 		provider:            opts.GetProvider(), | ||||
| 		provider:            provider, | ||||
| 		sessionStore:        sessionStore, | ||||
| 		redirectURL:         redirectURL, | ||||
| 		allowedRoutes:       allowedRoutes, | ||||
|  | @ -346,12 +351,12 @@ func buildPreAuthChain(opts *options.Options) (alice.Chain, error) { | |||
| 	return chain, nil | ||||
| } | ||||
| 
 | ||||
| func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionStore, validator basic.Validator) alice.Chain { | ||||
| func buildSessionChain(opts *options.Options, provider providers.Provider, sessionStore sessionsapi.SessionStore, validator basic.Validator) alice.Chain { | ||||
| 	chain := alice.New() | ||||
| 
 | ||||
| 	if opts.SkipJwtBearerTokens { | ||||
| 		sessionLoaders := []middlewareapi.TokenToSessionFunc{ | ||||
| 			opts.GetProvider().CreateSessionFromToken, | ||||
| 			provider.CreateSessionFromToken, | ||||
| 		} | ||||
| 
 | ||||
| 		for _, verifier := range opts.GetJWTBearerVerifiers() { | ||||
|  | @ -369,8 +374,8 @@ func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionSt | |||
| 	chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{ | ||||
| 		SessionStore:    sessionStore, | ||||
| 		RefreshPeriod:   opts.Cookie.Refresh, | ||||
| 		RefreshSession:  opts.GetProvider().RefreshSession, | ||||
| 		ValidateSession: opts.GetProvider().ValidateSession, | ||||
| 		RefreshSession:  provider.RefreshSession, | ||||
| 		ValidateSession: provider.ValidateSession, | ||||
| 	})) | ||||
| 
 | ||||
| 	return chain | ||||
|  |  | |||
|  | @ -161,13 +161,11 @@ func Test_enrichSession(t *testing.T) { | |||
| 			err := validation.Validate(opts) | ||||
| 			assert.NoError(t, err) | ||||
| 
 | ||||
| 			// intentionally set after validation.Validate(opts) since it will clobber
 | ||||
| 			// our TestProvider and call `providers.New` defaulting to `providers.GoogleProvider`
 | ||||
| 			opts.SetProvider(NewTestProvider(&url.URL{Host: "www.example.com"}, providerEmail)) | ||||
| 			proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) | ||||
| 			if err != nil { | ||||
| 				t.Fatal(err) | ||||
| 			} | ||||
| 			proxy.provider = NewTestProvider(&url.URL{Host: "www.example.com"}, providerEmail) | ||||
| 
 | ||||
| 			err = proxy.enrichSessionState(context.Background(), tc.session) | ||||
| 			assert.NoError(t, err) | ||||
|  | @ -232,13 +230,13 @@ func TestBasicAuthPassword(t *testing.T) { | |||
| 	providerURL, _ := url.Parse(providerServer.URL) | ||||
| 	const emailAddress = "john.doe@example.com" | ||||
| 
 | ||||
| 	opts.SetProvider(NewTestProvider(providerURL, emailAddress)) | ||||
| 	proxy, err := NewOAuthProxy(opts, func(email string) bool { | ||||
| 		return email == emailAddress | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	proxy.provider = NewTestProvider(providerURL, emailAddress) | ||||
| 
 | ||||
| 	// Save the required session
 | ||||
| 	rw := httptest.NewRecorder() | ||||
|  | @ -390,10 +388,10 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) (*PassAccessTokenTe | |||
| 
 | ||||
| 	testProvider := NewTestProvider(providerURL, emailAddress) | ||||
| 	testProvider.ValidToken = opts.ValidToken | ||||
| 	patt.opts.SetProvider(testProvider) | ||||
| 	patt.proxy, err = NewOAuthProxy(patt.opts, func(email string) bool { | ||||
| 		return email == emailAddress | ||||
| 	}) | ||||
| 	patt.proxy.provider = testProvider | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | @ -769,11 +767,17 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifi | |||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	pcTest.proxy.provider = &TestProvider{ | ||||
| 	testProvider := &TestProvider{ | ||||
| 		ProviderData: &providers.ProviderData{}, | ||||
| 		ValidToken:   opts.providerValidateCookieResponse, | ||||
| 	} | ||||
| 	pcTest.proxy.provider.(*TestProvider).SetAllowedGroups(pcTest.opts.Providers[0].AllowedGroups) | ||||
| 
 | ||||
| 	groups := pcTest.opts.Providers[0].AllowedGroups | ||||
| 	testProvider.AllowedGroups = make(map[string]struct{}, len(groups)) | ||||
| 	for _, group := range groups { | ||||
| 		testProvider.AllowedGroups[group] = struct{}{} | ||||
| 	} | ||||
| 	pcTest.proxy.provider = testProvider | ||||
| 
 | ||||
| 	// Now, zero-out proxy.CookieRefresh for the cases that don't involve
 | ||||
| 	// access_token validation.
 | ||||
|  | @ -1359,12 +1363,12 @@ func TestAuthSkippedForPreflightRequests(t *testing.T) { | |||
| 	assert.NoError(t, err) | ||||
| 
 | ||||
| 	upstreamURL, _ := url.Parse(upstreamServer.URL) | ||||
| 	opts.SetProvider(NewTestProvider(upstreamURL, "")) | ||||
| 
 | ||||
| 	proxy, err := NewOAuthProxy(opts, func(string) bool { return false }) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	proxy.provider = NewTestProvider(upstreamURL, "") | ||||
| 	rw := httptest.NewRecorder() | ||||
| 	req, _ := http.NewRequest("OPTIONS", "/preflight-request", nil) | ||||
| 	proxy.ServeHTTP(rw, req) | ||||
|  | @ -1409,6 +1413,7 @@ type SignatureTest struct { | |||
| 	header        http.Header | ||||
| 	rw            *httptest.ResponseRecorder | ||||
| 	authenticator *SignatureAuthenticator | ||||
| 	authProvider  providers.Provider | ||||
| } | ||||
| 
 | ||||
| func NewSignatureTest() (*SignatureTest, error) { | ||||
|  | @ -1443,7 +1448,7 @@ func NewSignatureTest() (*SignatureTest, error) { | |||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	opts.SetProvider(NewTestProvider(providerURL, "mbland@acm.org")) | ||||
| 	testProvider := NewTestProvider(providerURL, "mbland@acm.org") | ||||
| 
 | ||||
| 	return &SignatureTest{ | ||||
| 		opts, | ||||
|  | @ -1453,6 +1458,7 @@ func NewSignatureTest() (*SignatureTest, error) { | |||
| 		make(http.Header), | ||||
| 		httptest.NewRecorder(), | ||||
| 		authenticator, | ||||
| 		testProvider, | ||||
| 	}, nil | ||||
| } | ||||
| 
 | ||||
|  | @ -1486,6 +1492,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) er | |||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	proxy.provider = st.authProvider | ||||
| 
 | ||||
| 	var bodyBuf io.ReadCloser | ||||
| 	if body != "" { | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue