From 0791aef8cc5604990d3464dc7c533c32e8680069 Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Tue, 15 Feb 2022 12:00:06 +0000 Subject: [PATCH] Integrate new provider constructor in main --- oauthproxy.go | 21 +++++++++++++-------- oauthproxy_test.go | 25 ++++++++++++++++--------- 2 files changed, 29 insertions(+), 17 deletions(-) diff --git a/oauthproxy.go b/oauthproxy.go index 1cd35e65..12996d8f 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -114,6 +114,11 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr } } + provider, err := providers.NewProvider(opts.Providers[0]) + if err != nil { + return nil, fmt.Errorf("error intiailising provider: %v", err) + } + pageWriter, err := pagewriter.NewWriter(pagewriter.Opts{ TemplatesPath: opts.Templates.Path, CustomLogo: opts.Templates.CustomLogo, @@ -121,7 +126,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr Footer: opts.Templates.Footer, Version: VERSION, Debug: opts.Templates.Debug, - ProviderName: buildProviderName(opts.GetProvider(), opts.Providers[0].Name), + ProviderName: buildProviderName(provider, opts.Providers[0].Name), SignInMessage: buildSignInMessage(opts), DisplayLoginForm: basicAuthValidator != nil && opts.Templates.DisplayLoginForm, }) @@ -145,7 +150,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr redirectURL.Path = fmt.Sprintf("%s/callback", opts.ProxyPrefix) } - logger.Printf("OAuthProxy configured for %s Client ID: %s", opts.GetProvider().Data().ProviderName, opts.Providers[0].ClientID) + logger.Printf("OAuthProxy configured for %s Client ID: %s", provider.Data().ProviderName, opts.Providers[0].ClientID) refresh := "disabled" if opts.Cookie.Refresh != time.Duration(0) { refresh = fmt.Sprintf("after %s", opts.Cookie.Refresh) @@ -171,7 +176,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr if err != nil { return nil, fmt.Errorf("could not build pre-auth chain: %v", err) } - sessionChain := buildSessionChain(opts, sessionStore, basicAuthValidator) + sessionChain := buildSessionChain(opts, provider, sessionStore, basicAuthValidator) headersChain, err := buildHeadersChain(opts) if err != nil { return nil, fmt.Errorf("could not build headers chain: %v", err) @@ -190,7 +195,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr SignInPath: fmt.Sprintf("%s/sign_in", opts.ProxyPrefix), ProxyPrefix: opts.ProxyPrefix, - provider: opts.GetProvider(), + provider: provider, sessionStore: sessionStore, redirectURL: redirectURL, allowedRoutes: allowedRoutes, @@ -346,12 +351,12 @@ func buildPreAuthChain(opts *options.Options) (alice.Chain, error) { return chain, nil } -func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionStore, validator basic.Validator) alice.Chain { +func buildSessionChain(opts *options.Options, provider providers.Provider, sessionStore sessionsapi.SessionStore, validator basic.Validator) alice.Chain { chain := alice.New() if opts.SkipJwtBearerTokens { sessionLoaders := []middlewareapi.TokenToSessionFunc{ - opts.GetProvider().CreateSessionFromToken, + provider.CreateSessionFromToken, } for _, verifier := range opts.GetJWTBearerVerifiers() { @@ -369,8 +374,8 @@ func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionSt chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{ SessionStore: sessionStore, RefreshPeriod: opts.Cookie.Refresh, - RefreshSession: opts.GetProvider().RefreshSession, - ValidateSession: opts.GetProvider().ValidateSession, + RefreshSession: provider.RefreshSession, + ValidateSession: provider.ValidateSession, })) return chain diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 6f5a5538..3157c340 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -161,13 +161,11 @@ func Test_enrichSession(t *testing.T) { err := validation.Validate(opts) assert.NoError(t, err) - // intentionally set after validation.Validate(opts) since it will clobber - // our TestProvider and call `providers.New` defaulting to `providers.GoogleProvider` - opts.SetProvider(NewTestProvider(&url.URL{Host: "www.example.com"}, providerEmail)) proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) if err != nil { t.Fatal(err) } + proxy.provider = NewTestProvider(&url.URL{Host: "www.example.com"}, providerEmail) err = proxy.enrichSessionState(context.Background(), tc.session) assert.NoError(t, err) @@ -232,13 +230,13 @@ func TestBasicAuthPassword(t *testing.T) { providerURL, _ := url.Parse(providerServer.URL) const emailAddress = "john.doe@example.com" - opts.SetProvider(NewTestProvider(providerURL, emailAddress)) proxy, err := NewOAuthProxy(opts, func(email string) bool { return email == emailAddress }) if err != nil { t.Fatal(err) } + proxy.provider = NewTestProvider(providerURL, emailAddress) // Save the required session rw := httptest.NewRecorder() @@ -390,10 +388,10 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) (*PassAccessTokenTe testProvider := NewTestProvider(providerURL, emailAddress) testProvider.ValidToken = opts.ValidToken - patt.opts.SetProvider(testProvider) patt.proxy, err = NewOAuthProxy(patt.opts, func(email string) bool { return email == emailAddress }) + patt.proxy.provider = testProvider if err != nil { return nil, err } @@ -769,11 +767,17 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifi if err != nil { return nil, err } - pcTest.proxy.provider = &TestProvider{ + testProvider := &TestProvider{ ProviderData: &providers.ProviderData{}, ValidToken: opts.providerValidateCookieResponse, } - pcTest.proxy.provider.(*TestProvider).SetAllowedGroups(pcTest.opts.Providers[0].AllowedGroups) + + groups := pcTest.opts.Providers[0].AllowedGroups + testProvider.AllowedGroups = make(map[string]struct{}, len(groups)) + for _, group := range groups { + testProvider.AllowedGroups[group] = struct{}{} + } + pcTest.proxy.provider = testProvider // Now, zero-out proxy.CookieRefresh for the cases that don't involve // access_token validation. @@ -1359,12 +1363,12 @@ func TestAuthSkippedForPreflightRequests(t *testing.T) { assert.NoError(t, err) upstreamURL, _ := url.Parse(upstreamServer.URL) - opts.SetProvider(NewTestProvider(upstreamURL, "")) proxy, err := NewOAuthProxy(opts, func(string) bool { return false }) if err != nil { t.Fatal(err) } + proxy.provider = NewTestProvider(upstreamURL, "") rw := httptest.NewRecorder() req, _ := http.NewRequest("OPTIONS", "/preflight-request", nil) proxy.ServeHTTP(rw, req) @@ -1409,6 +1413,7 @@ type SignatureTest struct { header http.Header rw *httptest.ResponseRecorder authenticator *SignatureAuthenticator + authProvider providers.Provider } func NewSignatureTest() (*SignatureTest, error) { @@ -1443,7 +1448,7 @@ func NewSignatureTest() (*SignatureTest, error) { if err != nil { return nil, err } - opts.SetProvider(NewTestProvider(providerURL, "mbland@acm.org")) + testProvider := NewTestProvider(providerURL, "mbland@acm.org") return &SignatureTest{ opts, @@ -1453,6 +1458,7 @@ func NewSignatureTest() (*SignatureTest, error) { make(http.Header), httptest.NewRecorder(), authenticator, + testProvider, }, nil } @@ -1486,6 +1492,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) er if err != nil { return err } + proxy.provider = st.authProvider var bodyBuf io.ReadCloser if body != "" {