Integrate new provider constructor in main
This commit is contained in:
		
							parent
							
								
									2e15f57b70
								
							
						
					
					
						commit
						0791aef8cc
					
				|  | @ -114,6 +114,11 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	provider, err := providers.NewProvider(opts.Providers[0]) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("error intiailising provider: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	pageWriter, err := pagewriter.NewWriter(pagewriter.Opts{ | 	pageWriter, err := pagewriter.NewWriter(pagewriter.Opts{ | ||||||
| 		TemplatesPath:    opts.Templates.Path, | 		TemplatesPath:    opts.Templates.Path, | ||||||
| 		CustomLogo:       opts.Templates.CustomLogo, | 		CustomLogo:       opts.Templates.CustomLogo, | ||||||
|  | @ -121,7 +126,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | ||||||
| 		Footer:           opts.Templates.Footer, | 		Footer:           opts.Templates.Footer, | ||||||
| 		Version:          VERSION, | 		Version:          VERSION, | ||||||
| 		Debug:            opts.Templates.Debug, | 		Debug:            opts.Templates.Debug, | ||||||
| 		ProviderName:     buildProviderName(opts.GetProvider(), opts.Providers[0].Name), | 		ProviderName:     buildProviderName(provider, opts.Providers[0].Name), | ||||||
| 		SignInMessage:    buildSignInMessage(opts), | 		SignInMessage:    buildSignInMessage(opts), | ||||||
| 		DisplayLoginForm: basicAuthValidator != nil && opts.Templates.DisplayLoginForm, | 		DisplayLoginForm: basicAuthValidator != nil && opts.Templates.DisplayLoginForm, | ||||||
| 	}) | 	}) | ||||||
|  | @ -145,7 +150,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | ||||||
| 		redirectURL.Path = fmt.Sprintf("%s/callback", opts.ProxyPrefix) | 		redirectURL.Path = fmt.Sprintf("%s/callback", opts.ProxyPrefix) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	logger.Printf("OAuthProxy configured for %s Client ID: %s", opts.GetProvider().Data().ProviderName, opts.Providers[0].ClientID) | 	logger.Printf("OAuthProxy configured for %s Client ID: %s", provider.Data().ProviderName, opts.Providers[0].ClientID) | ||||||
| 	refresh := "disabled" | 	refresh := "disabled" | ||||||
| 	if opts.Cookie.Refresh != time.Duration(0) { | 	if opts.Cookie.Refresh != time.Duration(0) { | ||||||
| 		refresh = fmt.Sprintf("after %s", opts.Cookie.Refresh) | 		refresh = fmt.Sprintf("after %s", opts.Cookie.Refresh) | ||||||
|  | @ -171,7 +176,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("could not build pre-auth chain: %v", err) | 		return nil, fmt.Errorf("could not build pre-auth chain: %v", err) | ||||||
| 	} | 	} | ||||||
| 	sessionChain := buildSessionChain(opts, sessionStore, basicAuthValidator) | 	sessionChain := buildSessionChain(opts, provider, sessionStore, basicAuthValidator) | ||||||
| 	headersChain, err := buildHeadersChain(opts) | 	headersChain, err := buildHeadersChain(opts) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("could not build headers chain: %v", err) | 		return nil, fmt.Errorf("could not build headers chain: %v", err) | ||||||
|  | @ -190,7 +195,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | ||||||
| 		SignInPath: fmt.Sprintf("%s/sign_in", opts.ProxyPrefix), | 		SignInPath: fmt.Sprintf("%s/sign_in", opts.ProxyPrefix), | ||||||
| 
 | 
 | ||||||
| 		ProxyPrefix:         opts.ProxyPrefix, | 		ProxyPrefix:         opts.ProxyPrefix, | ||||||
| 		provider:            opts.GetProvider(), | 		provider:            provider, | ||||||
| 		sessionStore:        sessionStore, | 		sessionStore:        sessionStore, | ||||||
| 		redirectURL:         redirectURL, | 		redirectURL:         redirectURL, | ||||||
| 		allowedRoutes:       allowedRoutes, | 		allowedRoutes:       allowedRoutes, | ||||||
|  | @ -346,12 +351,12 @@ func buildPreAuthChain(opts *options.Options) (alice.Chain, error) { | ||||||
| 	return chain, nil | 	return chain, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionStore, validator basic.Validator) alice.Chain { | func buildSessionChain(opts *options.Options, provider providers.Provider, sessionStore sessionsapi.SessionStore, validator basic.Validator) alice.Chain { | ||||||
| 	chain := alice.New() | 	chain := alice.New() | ||||||
| 
 | 
 | ||||||
| 	if opts.SkipJwtBearerTokens { | 	if opts.SkipJwtBearerTokens { | ||||||
| 		sessionLoaders := []middlewareapi.TokenToSessionFunc{ | 		sessionLoaders := []middlewareapi.TokenToSessionFunc{ | ||||||
| 			opts.GetProvider().CreateSessionFromToken, | 			provider.CreateSessionFromToken, | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		for _, verifier := range opts.GetJWTBearerVerifiers() { | 		for _, verifier := range opts.GetJWTBearerVerifiers() { | ||||||
|  | @ -369,8 +374,8 @@ func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionSt | ||||||
| 	chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{ | 	chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{ | ||||||
| 		SessionStore:    sessionStore, | 		SessionStore:    sessionStore, | ||||||
| 		RefreshPeriod:   opts.Cookie.Refresh, | 		RefreshPeriod:   opts.Cookie.Refresh, | ||||||
| 		RefreshSession:  opts.GetProvider().RefreshSession, | 		RefreshSession:  provider.RefreshSession, | ||||||
| 		ValidateSession: opts.GetProvider().ValidateSession, | 		ValidateSession: provider.ValidateSession, | ||||||
| 	})) | 	})) | ||||||
| 
 | 
 | ||||||
| 	return chain | 	return chain | ||||||
|  |  | ||||||
|  | @ -161,13 +161,11 @@ func Test_enrichSession(t *testing.T) { | ||||||
| 			err := validation.Validate(opts) | 			err := validation.Validate(opts) | ||||||
| 			assert.NoError(t, err) | 			assert.NoError(t, err) | ||||||
| 
 | 
 | ||||||
| 			// intentionally set after validation.Validate(opts) since it will clobber
 |  | ||||||
| 			// our TestProvider and call `providers.New` defaulting to `providers.GoogleProvider`
 |  | ||||||
| 			opts.SetProvider(NewTestProvider(&url.URL{Host: "www.example.com"}, providerEmail)) |  | ||||||
| 			proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) | 			proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				t.Fatal(err) | 				t.Fatal(err) | ||||||
| 			} | 			} | ||||||
|  | 			proxy.provider = NewTestProvider(&url.URL{Host: "www.example.com"}, providerEmail) | ||||||
| 
 | 
 | ||||||
| 			err = proxy.enrichSessionState(context.Background(), tc.session) | 			err = proxy.enrichSessionState(context.Background(), tc.session) | ||||||
| 			assert.NoError(t, err) | 			assert.NoError(t, err) | ||||||
|  | @ -232,13 +230,13 @@ func TestBasicAuthPassword(t *testing.T) { | ||||||
| 	providerURL, _ := url.Parse(providerServer.URL) | 	providerURL, _ := url.Parse(providerServer.URL) | ||||||
| 	const emailAddress = "john.doe@example.com" | 	const emailAddress = "john.doe@example.com" | ||||||
| 
 | 
 | ||||||
| 	opts.SetProvider(NewTestProvider(providerURL, emailAddress)) |  | ||||||
| 	proxy, err := NewOAuthProxy(opts, func(email string) bool { | 	proxy, err := NewOAuthProxy(opts, func(email string) bool { | ||||||
| 		return email == emailAddress | 		return email == emailAddress | ||||||
| 	}) | 	}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
|  | 	proxy.provider = NewTestProvider(providerURL, emailAddress) | ||||||
| 
 | 
 | ||||||
| 	// Save the required session
 | 	// Save the required session
 | ||||||
| 	rw := httptest.NewRecorder() | 	rw := httptest.NewRecorder() | ||||||
|  | @ -390,10 +388,10 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) (*PassAccessTokenTe | ||||||
| 
 | 
 | ||||||
| 	testProvider := NewTestProvider(providerURL, emailAddress) | 	testProvider := NewTestProvider(providerURL, emailAddress) | ||||||
| 	testProvider.ValidToken = opts.ValidToken | 	testProvider.ValidToken = opts.ValidToken | ||||||
| 	patt.opts.SetProvider(testProvider) |  | ||||||
| 	patt.proxy, err = NewOAuthProxy(patt.opts, func(email string) bool { | 	patt.proxy, err = NewOAuthProxy(patt.opts, func(email string) bool { | ||||||
| 		return email == emailAddress | 		return email == emailAddress | ||||||
| 	}) | 	}) | ||||||
|  | 	patt.proxy.provider = testProvider | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  | @ -769,11 +767,17 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifi | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	pcTest.proxy.provider = &TestProvider{ | 	testProvider := &TestProvider{ | ||||||
| 		ProviderData: &providers.ProviderData{}, | 		ProviderData: &providers.ProviderData{}, | ||||||
| 		ValidToken:   opts.providerValidateCookieResponse, | 		ValidToken:   opts.providerValidateCookieResponse, | ||||||
| 	} | 	} | ||||||
| 	pcTest.proxy.provider.(*TestProvider).SetAllowedGroups(pcTest.opts.Providers[0].AllowedGroups) | 
 | ||||||
|  | 	groups := pcTest.opts.Providers[0].AllowedGroups | ||||||
|  | 	testProvider.AllowedGroups = make(map[string]struct{}, len(groups)) | ||||||
|  | 	for _, group := range groups { | ||||||
|  | 		testProvider.AllowedGroups[group] = struct{}{} | ||||||
|  | 	} | ||||||
|  | 	pcTest.proxy.provider = testProvider | ||||||
| 
 | 
 | ||||||
| 	// Now, zero-out proxy.CookieRefresh for the cases that don't involve
 | 	// Now, zero-out proxy.CookieRefresh for the cases that don't involve
 | ||||||
| 	// access_token validation.
 | 	// access_token validation.
 | ||||||
|  | @ -1359,12 +1363,12 @@ func TestAuthSkippedForPreflightRequests(t *testing.T) { | ||||||
| 	assert.NoError(t, err) | 	assert.NoError(t, err) | ||||||
| 
 | 
 | ||||||
| 	upstreamURL, _ := url.Parse(upstreamServer.URL) | 	upstreamURL, _ := url.Parse(upstreamServer.URL) | ||||||
| 	opts.SetProvider(NewTestProvider(upstreamURL, "")) |  | ||||||
| 
 | 
 | ||||||
| 	proxy, err := NewOAuthProxy(opts, func(string) bool { return false }) | 	proxy, err := NewOAuthProxy(opts, func(string) bool { return false }) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
|  | 	proxy.provider = NewTestProvider(upstreamURL, "") | ||||||
| 	rw := httptest.NewRecorder() | 	rw := httptest.NewRecorder() | ||||||
| 	req, _ := http.NewRequest("OPTIONS", "/preflight-request", nil) | 	req, _ := http.NewRequest("OPTIONS", "/preflight-request", nil) | ||||||
| 	proxy.ServeHTTP(rw, req) | 	proxy.ServeHTTP(rw, req) | ||||||
|  | @ -1409,6 +1413,7 @@ type SignatureTest struct { | ||||||
| 	header        http.Header | 	header        http.Header | ||||||
| 	rw            *httptest.ResponseRecorder | 	rw            *httptest.ResponseRecorder | ||||||
| 	authenticator *SignatureAuthenticator | 	authenticator *SignatureAuthenticator | ||||||
|  | 	authProvider  providers.Provider | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func NewSignatureTest() (*SignatureTest, error) { | func NewSignatureTest() (*SignatureTest, error) { | ||||||
|  | @ -1443,7 +1448,7 @@ func NewSignatureTest() (*SignatureTest, error) { | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	opts.SetProvider(NewTestProvider(providerURL, "mbland@acm.org")) | 	testProvider := NewTestProvider(providerURL, "mbland@acm.org") | ||||||
| 
 | 
 | ||||||
| 	return &SignatureTest{ | 	return &SignatureTest{ | ||||||
| 		opts, | 		opts, | ||||||
|  | @ -1453,6 +1458,7 @@ func NewSignatureTest() (*SignatureTest, error) { | ||||||
| 		make(http.Header), | 		make(http.Header), | ||||||
| 		httptest.NewRecorder(), | 		httptest.NewRecorder(), | ||||||
| 		authenticator, | 		authenticator, | ||||||
|  | 		testProvider, | ||||||
| 	}, nil | 	}, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -1486,6 +1492,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) er | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  | 	proxy.provider = st.authProvider | ||||||
| 
 | 
 | ||||||
| 	var bodyBuf io.ReadCloser | 	var bodyBuf io.ReadCloser | ||||||
| 	if body != "" { | 	if body != "" { | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue