From a1877434b21984b1a57839c2c31dca505bb7884e Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Thu, 26 Nov 2020 19:00:30 -0800 Subject: [PATCH 01/28] Refactor OIDC to EnrichSession --- pkg/validation/options.go | 2 +- providers/oidc.go | 328 +++++++++++++++++++------------------- providers/oidc_test.go | 284 +++++++++++++++++++++++++++++---- providers/util.go | 23 +++ providers/util_test.go | 49 +++++- 5 files changed, 489 insertions(+), 197 deletions(-) diff --git a/pkg/validation/options.go b/pkg/validation/options.go index fc89000b..0cea4b2b 100644 --- a/pkg/validation/options.go +++ b/pkg/validation/options.go @@ -274,7 +274,7 @@ func parseProviderInfo(o *options.Options, msgs []string) []string { p.SetRepository(o.BitbucketRepository) case *providers.OIDCProvider: p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail - p.UserIDClaim = o.UserIDClaim + p.EmailClaim = o.UserIDClaim p.GroupsClaim = o.OIDCGroupsClaim if p.Verifier == nil { msgs = append(msgs, "oidc provider requires an oidc issuer URL") diff --git a/providers/oidc.go b/providers/oidc.go index 704e3341..15020282 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -2,18 +2,17 @@ package providers import ( "context" - "encoding/json" + "errors" "fmt" "reflect" "strings" "time" - oidc "github.com/coreos/go-oidc" - "golang.org/x/oauth2" - + "github.com/coreos/go-oidc" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" + "golang.org/x/oauth2" ) const emailClaim = "email" @@ -23,7 +22,7 @@ type OIDCProvider struct { *ProviderData AllowUnverifiedEmail bool - UserIDClaim string + EmailClaim string GroupsClaim string } @@ -36,10 +35,10 @@ func NewOIDCProvider(p *ProviderData) *OIDCProvider { var _ Provider = (*OIDCProvider)(nil) // Redeem exchanges the OAuth2 authentication token for an ID token -func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { +func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) { clientSecret, err := p.GetClientSecret() if err != nil { - return + return nil, err } c := oauth2.Config{ @@ -52,23 +51,74 @@ func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (s } token, err := c.Exchange(ctx, code) if err != nil { - return nil, fmt.Errorf("token exchange: %v", err) + return nil, fmt.Errorf("token exchange failure: %v", err) } - // in the initial exchange the id token is mandatory - idToken, err := p.findVerifiedIDToken(ctx, token) + return p.createSession(ctx, token, false) +} + +// EnrichSessionState is called after Redeem to allow providers to enrich session fields +// such as User, Email, Groups with provider specific API calls. +func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error { + if p.ProfileURL.String() == "" { + if s.Email == "" { + return errors.New("id_token did not contain an email and profileURL is not defined") + } + return nil + } + + // Try to get missing emails or groups from a profileURL + if s.Email == "" || len(s.Groups) == 0 { + err := p.callProfileURL(ctx, s) + if err != nil { + logger.Errorf("Warning: Profile URL request failed: %v", err) + } + } + + // If a mandatory email wasn't set, error at this point. + if s.Email == "" { + return errors.New("neither the id_token nor the profileURL set an email") + } + return nil +} + +// callProfileURL enriches a session's Email & Groups via the JSON response of +// an OIDC profile URL +func (p *OIDCProvider) callProfileURL(ctx context.Context, s *sessions.SessionState) error { + respJSON, err := requests.New(p.ProfileURL.String()). + WithContext(ctx). + WithHeaders(makeOIDCHeader(s.AccessToken)). + Do(). + UnmarshalJSON() if err != nil { - return nil, fmt.Errorf("could not verify id_token: %v", err) - } else if idToken == nil { - return nil, fmt.Errorf("token response did not contain an id_token") + return err } - s, err = p.createSessionState(ctx, token, idToken) - if err != nil { - return nil, fmt.Errorf("unable to update session: %v", err) + email, err := respJSON.Get(p.EmailClaim).String() + if err == nil && s.Email == "" { + s.Email = email } - return + // Handle array & singleton groups cases + if len(s.Groups) == 0 { + groups, err := respJSON.Get(p.GroupsClaim).StringArray() + if err == nil { + s.Groups = groups + } else { + group, err := respJSON.Get(p.GroupsClaim).String() + if err == nil { + s.Groups = []string{group} + } + } + } + + return nil +} + +// ValidateSessionState checks that the session's IDToken is still valid +func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { + _, err := p.Verifier.Verify(ctx, s.IDToken) + return err == nil } // RefreshSessionIfNeeded checks if the session has expired and uses the @@ -83,14 +133,16 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.S return false, fmt.Errorf("unable to redeem refresh token: %v", err) } - fmt.Printf("refreshed access token %s (expired on %s)\n", s, s.ExpiresOn) + logger.Printf("refreshed session: %s", s) return true, nil } -func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) { +// redeemRefreshToken uses a RefreshToken with the RedeemURL to refresh the +// Access Token and (probably) the ID Token. +func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error { clientSecret, err := p.GetClientSecret() if err != nil { - return + return err } c := oauth2.Config{ @@ -109,19 +161,14 @@ func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sessi return fmt.Errorf("failed to get token: %v", err) } - // in the token refresh response the id_token is optional - idToken, err := p.findVerifiedIDToken(ctx, token) - if err != nil { - return fmt.Errorf("unable to extract id_token from response: %v", err) - } - - newSession, err := p.createSessionState(ctx, token, idToken) + newSession, err := p.createSession(ctx, token, true) if err != nil { return fmt.Errorf("unable create new session state from response: %v", err) } - // It's possible that if the refresh token isn't in the token response the session will not contain an id token - // if it doesn't it's probably better to retain the old one + // It's possible that if the refresh token isn't in the token response the + // session will not contain an id token. + // If it doesn't it's probably better to retain the old one if newSession.IDToken != "" { s.IDToken = newSession.IDToken s.Email = newSession.Email @@ -135,102 +182,113 @@ func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sessi s.CreatedAt = newSession.CreatedAt s.ExpiresOn = newSession.ExpiresOn - return -} - -func (p *OIDCProvider) findVerifiedIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) { - - getIDToken := func() (string, bool) { - rawIDToken, _ := token.Extra("id_token").(string) - return rawIDToken, len(strings.TrimSpace(rawIDToken)) > 0 - } - - if rawIDToken, present := getIDToken(); present { - verifiedIDToken, err := p.Verifier.Verify(ctx, rawIDToken) - return verifiedIDToken, err - } - return nil, nil -} - -func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token, idToken *oidc.IDToken) (*sessions.SessionState, error) { - - var newSession *sessions.SessionState - - if idToken == nil { - newSession = &sessions.SessionState{} - } else { - var err error - newSession, err = p.createSessionStateInternal(ctx, idToken, token) - if err != nil { - return nil, err - } - } - - created := time.Now() - newSession.AccessToken = token.AccessToken - newSession.RefreshToken = token.RefreshToken - newSession.CreatedAt = &created - newSession.ExpiresOn = &token.Expiry - return newSession, nil + return nil } +// CreateSessionFromToken converts Bearer IDTokens into sessions func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) { idToken, err := p.Verifier.Verify(ctx, token) if err != nil { return nil, err } - newSession, err := p.createSessionStateInternal(ctx, idToken, nil) + ss, err := p.buildSessionFromClaims(idToken) if err != nil { return nil, err } - newSession.AccessToken = token - newSession.IDToken = token - newSession.RefreshToken = "" - newSession.ExpiresOn = &idToken.Expiry - - return newSession, nil -} - -func (p *OIDCProvider) createSessionStateInternal(ctx context.Context, idToken *oidc.IDToken, token *oauth2.Token) (*sessions.SessionState, error) { - - newSession := &sessions.SessionState{} - - if idToken == nil { - return newSession, nil + // Allow empty Email in Bearer case since we can't hit the ProfileURL + if ss.Email == "" { + ss.Email = ss.User } - claims, err := p.findClaimsFromIDToken(ctx, idToken, token) + ss.AccessToken = token + ss.IDToken = token + ss.RefreshToken = "" + ss.ExpiresOn = &idToken.Expiry + + return ss, nil +} + +// createSession takes an oauth2.Token and creates a SessionState from it. +// It alters behavior if called from Redeem vs Refresh +func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, refresh bool) (*sessions.SessionState, error) { + idToken, err := p.findVerifiedIDToken(ctx, token) + if err != nil { + return nil, fmt.Errorf("could not verify id_token: %v", err) + } + + // IDToken is mandatory in Redeem but optional in Refresh + if idToken == nil && !refresh { + return nil, errors.New("token response did not contain an id_token") + } + + ss, err := p.buildSessionFromClaims(idToken) + if err != nil { + return nil, err + } + + ss.AccessToken = token.AccessToken + ss.RefreshToken = token.RefreshToken + ss.IDToken = getIDToken(token) + + created := time.Now() + ss.CreatedAt = &created + ss.ExpiresOn = &token.Expiry + + return ss, nil +} + +func (p *OIDCProvider) findVerifiedIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) { + rawIDToken := getIDToken(token) + if strings.TrimSpace(rawIDToken) != "" { + return p.Verifier.Verify(ctx, rawIDToken) + } + return nil, nil +} + +// buildSessionFromClaims uses IDToken claims to populate a fresh SessionState +// with non-Token related fields. +func (p *OIDCProvider) buildSessionFromClaims(idToken *oidc.IDToken) (*sessions.SessionState, error) { + ss := &sessions.SessionState{} + + if idToken == nil { + return ss, nil + } + + claims, err := p.getClaims(idToken) if err != nil { return nil, fmt.Errorf("couldn't extract claims from id_token (%v)", err) } - if token != nil { - newSession.IDToken = token.Extra("id_token").(string) + ss.User = claims.Subject + ss.Email = claims.Email + ss.Groups = claims.Groups + + // TODO (@NickMeves) Deprecate for dynamic claim to session mapping + if pref, ok := claims.rawClaims["preferred_username"].(string); ok { + ss.PreferredUsername = pref } - newSession.Email = claims.UserID // TODO Rename SessionState.Email to .UserID in the near future - - newSession.User = claims.Subject - newSession.Groups = claims.Groups - newSession.PreferredUsername = claims.PreferredUsername - - verifyEmail := (p.UserIDClaim == emailClaim) && !p.AllowUnverifiedEmail + verifyEmail := (p.EmailClaim == emailClaim) && !p.AllowUnverifiedEmail if verifyEmail && claims.Verified != nil && !*claims.Verified { - return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.UserID) + return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) } - return newSession, nil + return ss, nil } -// ValidateSessionState checks that the session's IDToken is still valid -func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { - _, err := p.Verifier.Verify(ctx, s.IDToken) - return err == nil +type OIDCClaims struct { + Subject string `json:"sub"` + Email string `json:"-"` + Groups []string `json:"-"` + Verified *bool `json:"email_verified"` + + rawClaims map[string]interface{} } -func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc.IDToken, token *oauth2.Token) (*OIDCClaims, error) { +// getClaims extracts IDToken claims into an OIDCClaims +func (p *OIDCProvider) getClaims(idToken *oidc.IDToken) (*OIDCClaims, error) { claims := &OIDCClaims{} // Extract default claims. @@ -242,86 +300,28 @@ func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc. return nil, fmt.Errorf("failed to parse all id_token claims: %v", err) } - userID := claims.rawClaims[p.UserIDClaim] - if userID != nil { - claims.UserID = fmt.Sprint(userID) - } - - claims.Groups = p.extractGroupsFromRawClaims(claims.rawClaims) - - // userID claim was not present or was empty in the ID Token - if claims.UserID == "" { - // BearerToken case, allow empty UserID - // ProfileURL checks below won't work since we don't have an access token - if token == nil { - claims.UserID = claims.Subject - return claims, nil - } - - profileURL := p.ProfileURL.String() - if profileURL == "" || token.AccessToken == "" { - return nil, fmt.Errorf("id_token did not contain user ID claim (%q)", p.UserIDClaim) - } - - // If the userinfo endpoint profileURL is defined, then there is a chance the userinfo - // contents at the profileURL contains the email. - // Make a query to the userinfo endpoint, and attempt to locate the email from there. - respJSON, err := requests.New(profileURL). - WithContext(ctx). - WithHeaders(makeOIDCHeader(token.AccessToken)). - Do(). - UnmarshalJSON() - if err != nil { - return nil, err - } - - userID, err := respJSON.Get(p.UserIDClaim).String() - if err != nil { - return nil, fmt.Errorf("neither id_token nor userinfo endpoint contained user ID claim (%q)", p.UserIDClaim) - } - - claims.UserID = userID + email := claims.rawClaims[p.EmailClaim] + if email != nil { + claims.Email = fmt.Sprint(email) } + claims.Groups = p.extractGroups(claims.rawClaims) return claims, nil } -func (p *OIDCProvider) extractGroupsFromRawClaims(rawClaims map[string]interface{}) []string { +func (p *OIDCProvider) extractGroups(claims map[string]interface{}) []string { groups := []string{} - - rawGroups, ok := rawClaims[p.GroupsClaim].([]interface{}) + rawGroups, ok := claims[p.GroupsClaim].([]interface{}) if rawGroups != nil && ok { for _, rawGroup := range rawGroups { formattedGroup, err := formatGroup(rawGroup) if err != nil { - logger.Errorf("unable to format group of type %s with error %s", reflect.TypeOf(rawGroup), err) + logger.Errorf("Warning: unable to format group of type %s with error %s", + reflect.TypeOf(rawGroup), err) continue } groups = append(groups, formattedGroup) } } - return groups - -} - -func formatGroup(rawGroup interface{}) (string, error) { - group, ok := rawGroup.(string) - if !ok { - jsonGroup, err := json.Marshal(rawGroup) - if err != nil { - return "", err - } - group = string(jsonGroup) - } - return group, nil -} - -type OIDCClaims struct { - rawClaims map[string]interface{} - UserID string - Subject string `json:"sub"` - Verified *bool `json:"email_verified"` - PreferredUsername string `json:"preferred_username"` - Groups []string `json:"-"` } diff --git a/providers/oidc_test.go b/providers/oidc_test.go index 54df1b31..0cc6ef76 100644 --- a/providers/oidc_test.go +++ b/providers/oidc_test.go @@ -154,7 +154,7 @@ func newOIDCProvider(serverURL *url.URL) *OIDCProvider { p := &OIDCProvider{ ProviderData: providerData, - UserIDClaim: "email", + EmailClaim: "email", } return p @@ -225,7 +225,7 @@ func TestOIDCProviderRedeem_custom_userid(t *testing.T) { }) server, provider := newTestSetup(body) - provider.UserIDClaim = "phone_number" + provider.EmailClaim = "phone_number" defer server.Close() session, err := provider.Redeem(context.Background(), provider.RedeemURL.String(), "code1234") @@ -233,6 +233,256 @@ func TestOIDCProviderRedeem_custom_userid(t *testing.T) { assert.Equal(t, defaultIDToken.Phone, session.Email) } +func TestOIDCProvider_EnrichSession(t *testing.T) { + const ( + idToken = "Unchanged ID Token" + accessToken = "Unchanged Access Token" + refreshToken = "Unchanged Refresh Token" + ) + + testCases := map[string]struct { + ExistingSession *sessions.SessionState + EmailClaim string + GroupsClaim string + ProfileJSON map[string]interface{} + ExpectedError error + ExpectedSession *sessions.SessionState + }{ + "Already Populated": { + ExistingSession: &sessions.SessionState{ + User: "already", + Email: "already@populated.com", + Groups: []string{"already", "populated"}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + EmailClaim: "email", + GroupsClaim: "groups", + ProfileJSON: map[string]interface{}{ + "email": "new@thing.com", + "groups": []string{"new", "thing"}, + }, + ExpectedError: nil, + ExpectedSession: &sessions.SessionState{ + User: "already", + Email: "already@populated.com", + Groups: []string{"already", "populated"}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + }, + "Missing Email": { + ExistingSession: &sessions.SessionState{ + User: "missing.email", + Groups: []string{"already", "populated"}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + EmailClaim: "email", + GroupsClaim: "groups", + ProfileJSON: map[string]interface{}{ + "email": "found@email.com", + "groups": []string{"new", "thing"}, + }, + ExpectedError: nil, + ExpectedSession: &sessions.SessionState{ + User: "missing.email", + Email: "found@email.com", + Groups: []string{"already", "populated"}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + }, + + "Missing Email Only in Profile URL": { + ExistingSession: &sessions.SessionState{ + User: "missing.email", + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + EmailClaim: "email", + GroupsClaim: "groups", + ProfileJSON: map[string]interface{}{ + "email": "found@email.com", + }, + ExpectedError: nil, + ExpectedSession: &sessions.SessionState{ + User: "missing.email", + Email: "found@email.com", + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + }, + "Missing Email with Custom Claim": { + ExistingSession: &sessions.SessionState{ + User: "missing.email", + Groups: []string{"already", "populated"}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + EmailClaim: "weird", + GroupsClaim: "groups", + ProfileJSON: map[string]interface{}{ + "weird": "weird@claim.com", + "groups": []string{"new", "thing"}, + }, + ExpectedError: nil, + ExpectedSession: &sessions.SessionState{ + User: "missing.email", + Email: "weird@claim.com", + Groups: []string{"already", "populated"}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + }, + "Missing Email not in Profile URL": { + ExistingSession: &sessions.SessionState{ + User: "missing.email", + Groups: []string{"already", "populated"}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + EmailClaim: "email", + GroupsClaim: "groups", + ProfileJSON: map[string]interface{}{ + "groups": []string{"new", "thing"}, + }, + ExpectedError: errors.New("neither the id_token nor the profileURL set an email"), + ExpectedSession: &sessions.SessionState{ + User: "missing.email", + Groups: []string{"already", "populated"}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + }, + "Missing Groups": { + ExistingSession: &sessions.SessionState{ + User: "already", + Email: "already@populated.com", + Groups: []string{}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + EmailClaim: "email", + GroupsClaim: "groups", + ProfileJSON: map[string]interface{}{ + "email": "new@thing.com", + "groups": []string{"new", "thing"}, + }, + ExpectedError: nil, + ExpectedSession: &sessions.SessionState{ + User: "already", + Email: "already@populated.com", + Groups: []string{"new", "thing"}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + }, + "Missing Groups with Custom Claim": { + ExistingSession: &sessions.SessionState{ + User: "already", + Email: "already@populated.com", + Groups: nil, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + EmailClaim: "email", + GroupsClaim: "roles", + ProfileJSON: map[string]interface{}{ + "email": "new@thing.com", + "roles": []string{"new", "thing", "roles"}, + }, + ExpectedError: nil, + ExpectedSession: &sessions.SessionState{ + User: "already", + Email: "already@populated.com", + Groups: []string{"new", "thing", "roles"}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + }, + "Missing Groups String Profile URL Response": { + ExistingSession: &sessions.SessionState{ + User: "already", + Email: "already@populated.com", + Groups: []string{}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + EmailClaim: "email", + GroupsClaim: "groups", + ProfileJSON: map[string]interface{}{ + "email": "new@thing.com", + "groups": "singleton", + }, + ExpectedError: nil, + ExpectedSession: &sessions.SessionState{ + User: "already", + Email: "already@populated.com", + Groups: []string{"singleton"}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + }, + "Missing Groups in both Claims and Profile URL": { + ExistingSession: &sessions.SessionState{ + User: "already", + Email: "already@populated.com", + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + EmailClaim: "email", + GroupsClaim: "groups", + ProfileJSON: map[string]interface{}{ + "email": "new@thing.com", + }, + ExpectedError: nil, + ExpectedSession: &sessions.SessionState{ + User: "already", + Email: "already@populated.com", + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + }, + } + for testName, tc := range testCases { + t.Run(testName, func(t *testing.T) { + jsonResp, err := json.Marshal(tc.ProfileJSON) + assert.NoError(t, err) + + server, provider := newTestSetup(jsonResp) + provider.ProfileURL, err = url.Parse(server.URL) + assert.NoError(t, err) + + provider.EmailClaim = tc.EmailClaim + provider.GroupsClaim = tc.GroupsClaim + defer server.Close() + + err = provider.EnrichSession(context.Background(), tc.ExistingSession) + assert.Equal(t, tc.ExpectedError, err) + assert.Equal(t, *tc.ExpectedSession, *tc.ExistingSession) + }) + } +} + func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) { idToken, _ := newSignedTestIDToken(defaultIDToken) @@ -361,7 +611,7 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) { } } -func TestOIDCProvider_findVerifiedIdToken(t *testing.T) { +func TestOIDCProvider_findVerifiedIDToken(t *testing.T) { server, provider := newTestSetup([]byte("")) @@ -397,31 +647,3 @@ func TestOIDCProvider_findVerifiedIdToken(t *testing.T) { assert.Equal(t, nil, err) assert.Equal(t, true, verifiedIDToken == nil) } - -func Test_formatGroup(t *testing.T) { - testCases := map[string]struct { - RawGroup interface{} - ExpectedFormattedGroupValue string - }{ - "String Group": { - RawGroup: "group", - ExpectedFormattedGroupValue: "group", - }, - "Map Group": { - RawGroup: map[string]string{"id": "1", "name": "Test"}, - ExpectedFormattedGroupValue: "{\"id\":\"1\",\"name\":\"Test\"}", - }, - "List Group": { - RawGroup: []string{"First", "Second"}, - ExpectedFormattedGroupValue: "[\"First\",\"Second\"]", - }, - } - - for testName, tc := range testCases { - t.Run(testName, func(t *testing.T) { - formattedGroup, err := formatGroup(tc.RawGroup) - assert.Nil(t, err) - assert.Equal(t, tc.ExpectedFormattedGroupValue, formattedGroup) - }) - } -} diff --git a/providers/util.go b/providers/util.go index b4b65ac6..acf20902 100644 --- a/providers/util.go +++ b/providers/util.go @@ -1,9 +1,12 @@ package providers import ( + "encoding/json" "fmt" "net/http" "net/url" + + "golang.org/x/oauth2" ) const ( @@ -55,3 +58,23 @@ func makeLoginURL(p *ProviderData, redirectURI, state string, extraParams url.Va a.RawQuery = params.Encode() return a } + +func getIDToken(token *oauth2.Token) string { + idToken, ok := token.Extra("id_token").(string) + if !ok { + return "" + } + return idToken +} + +func formatGroup(rawGroup interface{}) (string, error) { + group, ok := rawGroup.(string) + if !ok { + jsonGroup, err := json.Marshal(rawGroup) + if err != nil { + return "", err + } + group = string(jsonGroup) + } + return group, nil +} diff --git a/providers/util_test.go b/providers/util_test.go index 798df6cb..e14ff061 100644 --- a/providers/util_test.go +++ b/providers/util_test.go @@ -5,9 +5,10 @@ import ( "testing" . "github.com/onsi/gomega" + "golang.org/x/oauth2" ) -func TestMakeAuhtorizationHeader(t *testing.T) { +func Test_makeAuthorizationHeader(t *testing.T) { testCases := []struct { name string prefix string @@ -64,3 +65,49 @@ func TestMakeAuhtorizationHeader(t *testing.T) { }) } } + +func Test_getIDToken(t *testing.T) { + const idToken = "eyJfoobar.eyJfoobar.12345asdf" + g := NewWithT(t) + + token := &oauth2.Token{} + g.Expect(getIDToken(token)).To(Equal("")) + + extraToken := token.WithExtra(map[string]interface{}{ + "id_token": idToken, + }) + g.Expect(getIDToken(extraToken)).To(Equal(idToken)) +} + +func Test_formatGroup(t *testing.T) { + testCases := map[string]struct { + rawGroup interface{} + expected string + }{ + "String Group": { + rawGroup: "group", + expected: "group", + }, + "Numeric Group": { + rawGroup: 123, + expected: "123", + }, + "Map Group": { + rawGroup: map[string]string{"id": "1", "name": "Test"}, + expected: "{\"id\":\"1\",\"name\":\"Test\"}", + }, + "List Group": { + rawGroup: []string{"First", "Second"}, + expected: "[\"First\",\"Second\"]", + }, + } + + for testName, tc := range testCases { + t.Run(testName, func(t *testing.T) { + g := NewWithT(t) + formattedGroup, err := formatGroup(tc.rawGroup) + g.Expect(err).ToNot(HaveOccurred()) + g.Expect(formattedGroup).To(Equal(tc.expected)) + }) + } +} From 74ac4274c6643cb205336a86f6ef40d16805c59a Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Fri, 27 Nov 2020 16:17:59 -0800 Subject: [PATCH 02/28] Move generic OIDC functionality to be available to all providers --- pkg/validation/options.go | 5 +- providers/oidc.go | 98 +------- providers/oidc_test.go | 206 ++-------------- providers/provider_data.go | 109 ++++++++- providers/provider_data_test.go | 414 ++++++++++++++++++++++++++++++++ providers/provider_default.go | 2 + 6 files changed, 551 insertions(+), 283 deletions(-) create mode 100644 providers/provider_data_test.go diff --git a/pkg/validation/options.go b/pkg/validation/options.go index 0cea4b2b..652ada9e 100644 --- a/pkg/validation/options.go +++ b/pkg/validation/options.go @@ -233,7 +233,10 @@ func parseProviderInfo(o *options.Options, msgs []string) []string { p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs) p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs) - // Make the OIDC Verifier accessible to all providers that can support it + // Make the OIDC options available to all providers that support it + p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail + p.EmailClaim = o.UserIDClaim + p.GroupsClaim = o.OIDCGroupsClaim p.Verifier = o.GetOIDCVerifier() p.SetAllowedGroups(o.AllowedGroups) diff --git a/providers/oidc.go b/providers/oidc.go index 15020282..f90348d6 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -4,26 +4,17 @@ import ( "context" "errors" "fmt" - "reflect" - "strings" "time" - "github.com/coreos/go-oidc" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" "golang.org/x/oauth2" ) -const emailClaim = "email" - // OIDCProvider represents an OIDC based Identity Provider type OIDCProvider struct { *ProviderData - - AllowUnverifiedEmail bool - EmailClaim string - GroupsClaim string } // NewOIDCProvider initiates a new OIDCProvider @@ -213,7 +204,7 @@ func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string) // createSession takes an oauth2.Token and creates a SessionState from it. // It alters behavior if called from Redeem vs Refresh func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, refresh bool) (*sessions.SessionState, error) { - idToken, err := p.findVerifiedIDToken(ctx, token) + idToken, err := p.verifyIDToken(ctx, token) if err != nil { return nil, fmt.Errorf("could not verify id_token: %v", err) } @@ -238,90 +229,3 @@ func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, r return ss, nil } - -func (p *OIDCProvider) findVerifiedIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) { - rawIDToken := getIDToken(token) - if strings.TrimSpace(rawIDToken) != "" { - return p.Verifier.Verify(ctx, rawIDToken) - } - return nil, nil -} - -// buildSessionFromClaims uses IDToken claims to populate a fresh SessionState -// with non-Token related fields. -func (p *OIDCProvider) buildSessionFromClaims(idToken *oidc.IDToken) (*sessions.SessionState, error) { - ss := &sessions.SessionState{} - - if idToken == nil { - return ss, nil - } - - claims, err := p.getClaims(idToken) - if err != nil { - return nil, fmt.Errorf("couldn't extract claims from id_token (%v)", err) - } - - ss.User = claims.Subject - ss.Email = claims.Email - ss.Groups = claims.Groups - - // TODO (@NickMeves) Deprecate for dynamic claim to session mapping - if pref, ok := claims.rawClaims["preferred_username"].(string); ok { - ss.PreferredUsername = pref - } - - verifyEmail := (p.EmailClaim == emailClaim) && !p.AllowUnverifiedEmail - if verifyEmail && claims.Verified != nil && !*claims.Verified { - return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) - } - - return ss, nil -} - -type OIDCClaims struct { - Subject string `json:"sub"` - Email string `json:"-"` - Groups []string `json:"-"` - Verified *bool `json:"email_verified"` - - rawClaims map[string]interface{} -} - -// getClaims extracts IDToken claims into an OIDCClaims -func (p *OIDCProvider) getClaims(idToken *oidc.IDToken) (*OIDCClaims, error) { - claims := &OIDCClaims{} - - // Extract default claims. - if err := idToken.Claims(&claims); err != nil { - return nil, fmt.Errorf("failed to parse default id_token claims: %v", err) - } - // Extract custom claims. - if err := idToken.Claims(&claims.rawClaims); err != nil { - return nil, fmt.Errorf("failed to parse all id_token claims: %v", err) - } - - email := claims.rawClaims[p.EmailClaim] - if email != nil { - claims.Email = fmt.Sprint(email) - } - claims.Groups = p.extractGroups(claims.rawClaims) - - return claims, nil -} - -func (p *OIDCProvider) extractGroups(claims map[string]interface{}) []string { - groups := []string{} - rawGroups, ok := claims[p.GroupsClaim].([]interface{}) - if rawGroups != nil && ok { - for _, rawGroup := range rawGroups { - formattedGroup, err := formatGroup(rawGroup) - if err != nil { - logger.Errorf("Warning: unable to format group of type %s with error %s", - reflect.TypeOf(rawGroup), err) - continue - } - groups = append(groups, formattedGroup) - } - } - return groups -} diff --git a/providers/oidc_test.go b/providers/oidc_test.go index 0cc6ef76..2651b4ea 100644 --- a/providers/oidc_test.go +++ b/providers/oidc_test.go @@ -2,42 +2,18 @@ package providers import ( "context" - "crypto/rand" - "crypto/rsa" - "encoding/base64" "encoding/json" "errors" - "fmt" "net/http" "net/http/httptest" "net/url" - "strings" "testing" - "time" "github.com/coreos/go-oidc" - "github.com/dgrijalva/jwt-go" - "github.com/stretchr/testify/assert" - "golang.org/x/oauth2" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" + "github.com/stretchr/testify/assert" ) -const accessToken = "access_token" -const refreshToken = "refresh_token" -const clientID = "https://test.myapp.com" -const secret = "secret" - -type idTokenClaims struct { - Name string `json:"name,omitempty"` - Email string `json:"email,omitempty"` - Phone string `json:"phone_number,omitempty"` - Picture string `json:"picture,omitempty"` - Groups interface{} `json:"groups,omitempty"` - OtherGroups interface{} `json:"other_groups,omitempty"` - jwt.StandardClaims -} - type redeemTokenResponse struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` @@ -46,88 +22,12 @@ type redeemTokenResponse struct { IDToken string `json:"id_token,omitempty"` } -var defaultIDToken idTokenClaims = idTokenClaims{ - "Jane Dobbs", - "janed@me.com", - "+4798765432", - "http://mugbook.com/janed/me.jpg", - []string{"test:a", "test:b"}, - []string{"test:c", "test:d"}, - jwt.StandardClaims{ - Audience: "https://test.myapp.com", - ExpiresAt: time.Now().Add(time.Duration(5) * time.Minute).Unix(), - Id: "id-some-id", - IssuedAt: time.Now().Unix(), - Issuer: "https://issuer.example.com", - NotBefore: 0, - Subject: "123456789", - }, -} - -var customGroupClaimIDToken idTokenClaims = idTokenClaims{ - "Jane Dobbs", - "janed@me.com", - "+4798765432", - "http://mugbook.com/janed/me.jpg", - []map[string]interface{}{ - { - "groupId": "Admin Group Id", - "roles": []string{"Admin"}, - }, - }, - []string{"test:c", "test:d"}, - jwt.StandardClaims{ - Audience: "https://test.myapp.com", - ExpiresAt: time.Now().Add(time.Duration(5) * time.Minute).Unix(), - Id: "id-some-id", - IssuedAt: time.Now().Unix(), - Issuer: "https://issuer.example.com", - NotBefore: 0, - Subject: "123456789", - }, -} - -var minimalIDToken idTokenClaims = idTokenClaims{ - "", - "", - "", - "", - []string{}, - []string{}, - jwt.StandardClaims{ - Audience: "https://test.myapp.com", - ExpiresAt: time.Now().Add(time.Duration(5) * time.Minute).Unix(), - Id: "id-some-id", - IssuedAt: time.Now().Unix(), - Issuer: "https://issuer.example.com", - NotBefore: 0, - Subject: "minimal", - }, -} - -type fakeKeySetStub struct{} - -func (fakeKeySetStub) VerifySignature(_ context.Context, jwt string) (payload []byte, err error) { - decodeString, err := base64.RawURLEncoding.DecodeString(strings.Split(jwt, ".")[1]) - if err != nil { - return nil, err - } - tokenClaims := &idTokenClaims{} - err = json.Unmarshal(decodeString, tokenClaims) - - if err != nil || tokenClaims.Id == "this-id-fails-validation" { - return nil, fmt.Errorf("the validation failed for subject [%v]", tokenClaims.Subject) - } - - return decodeString, err -} - func newOIDCProvider(serverURL *url.URL) *OIDCProvider { providerData := &ProviderData{ ProviderName: "oidc", - ClientID: clientID, - ClientSecret: secret, + ClientID: oidcClientID, + ClientSecret: oidcSecret, LoginURL: &url.URL{ Scheme: serverURL.Scheme, Host: serverURL.Host, @@ -144,18 +44,17 @@ func newOIDCProvider(serverURL *url.URL) *OIDCProvider { Scheme: serverURL.Scheme, Host: serverURL.Host, Path: "/api"}, - Scope: "openid profile offline_access", + Scope: "openid profile offline_access", + EmailClaim: "email", + GroupsClaim: "groups", Verifier: oidc.NewVerifier( - "https://issuer.example.com", - fakeKeySetStub{}, - &oidc.Config{ClientID: clientID}, + oidcIssuer, + mockJWKS{}, + &oidc.Config{ClientID: oidcClientID}, ), } - p := &OIDCProvider{ - ProviderData: providerData, - EmailClaim: "email", - } + p := &OIDCProvider{ProviderData: providerData} return p } @@ -169,21 +68,6 @@ func newOIDCServer(body []byte) (*url.URL, *httptest.Server) { return u, s } -func newSignedTestIDToken(tokenClaims idTokenClaims) (string, error) { - key, _ := rsa.GenerateKey(rand.Reader, 2048) - standardClaims := jwt.NewWithClaims(jwt.SigningMethodRS256, tokenClaims) - return standardClaims.SignedString(key) -} - -func newOauth2Token() *oauth2.Token { - return &oauth2.Token{ - AccessToken: accessToken, - TokenType: "Bearer", - RefreshToken: refreshToken, - Expiry: time.Time{}.Add(time.Duration(5) * time.Second), - } -} - func newTestSetup(body []byte) (*httptest.Server, *OIDCProvider) { redeemURL, server := newOIDCServer(body) provider := newOIDCProvider(redeemURL) @@ -234,12 +118,6 @@ func TestOIDCProviderRedeem_custom_userid(t *testing.T) { } func TestOIDCProvider_EnrichSession(t *testing.T) { - const ( - idToken = "Unchanged ID Token" - accessToken = "Unchanged Access Token" - refreshToken = "Unchanged Refresh Token" - ) - testCases := map[string]struct { ExistingSession *sessions.SessionState EmailClaim string @@ -550,8 +428,6 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) { } func TestOIDCProviderCreateSessionFromToken(t *testing.T) { - const profileURLEmail = "janed@me.com" - testCases := map[string]struct { IDToken idTokenClaims GroupsClaim string @@ -562,36 +438,35 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) { "Default IDToken": { IDToken: defaultIDToken, GroupsClaim: "groups", - ExpectedUser: defaultIDToken.Subject, - ExpectedEmail: defaultIDToken.Email, + ExpectedUser: "123456789", + ExpectedEmail: "janed@me.com", ExpectedGroups: []string{"test:a", "test:b"}, }, "Minimal IDToken with no email claim": { IDToken: minimalIDToken, GroupsClaim: "groups", - ExpectedUser: minimalIDToken.Subject, - ExpectedEmail: minimalIDToken.Subject, + ExpectedUser: "123456789", + ExpectedEmail: "123456789", ExpectedGroups: []string{}, }, "Custom Groups Claim": { IDToken: defaultIDToken, - GroupsClaim: "other_groups", - ExpectedUser: defaultIDToken.Subject, - ExpectedEmail: defaultIDToken.Email, + GroupsClaim: "roles", + ExpectedUser: "123456789", + ExpectedEmail: "janed@me.com", ExpectedGroups: []string{"test:c", "test:d"}, }, - "Custom Groups Claim2": { - IDToken: customGroupClaimIDToken, + "Complex Groups Claim": { + IDToken: complexGroupsIDToken, GroupsClaim: "groups", - ExpectedUser: customGroupClaimIDToken.Subject, - ExpectedEmail: customGroupClaimIDToken.Email, + ExpectedUser: "123456789", + ExpectedEmail: "complex@claims.com", ExpectedGroups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"}, }, } for testName, tc := range testCases { t.Run(testName, func(t *testing.T) { - jsonResp := []byte(fmt.Sprintf(`{"email":"%s"}`, profileURLEmail)) - server, provider := newTestSetup(jsonResp) + server, provider := newTestSetup([]byte(`{}`)) provider.GroupsClaim = tc.GroupsClaim defer server.Close() @@ -610,40 +485,3 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) { }) } } - -func TestOIDCProvider_findVerifiedIDToken(t *testing.T) { - - server, provider := newTestSetup([]byte("")) - - defer server.Close() - - token := newOauth2Token() - signedIDToken, _ := newSignedTestIDToken(defaultIDToken) - tokenWithIDToken := token.WithExtra(map[string]interface{}{ - "id_token": signedIDToken, - }) - - verifiedIDToken, err := provider.findVerifiedIDToken(context.Background(), tokenWithIDToken) - assert.Equal(t, true, err == nil) - if verifiedIDToken == nil { - t.Fatal("verifiedIDToken is nil") - } - assert.Equal(t, defaultIDToken.Issuer, verifiedIDToken.Issuer) - assert.Equal(t, defaultIDToken.Subject, verifiedIDToken.Subject) - - // When the validation fails the response should be nil - defaultIDToken.Id = "this-id-fails-validation" - signedIDToken, _ = newSignedTestIDToken(defaultIDToken) - tokenWithIDToken = token.WithExtra(map[string]interface{}{ - "id_token": signedIDToken, - }) - - verifiedIDToken, err = provider.findVerifiedIDToken(context.Background(), tokenWithIDToken) - assert.Equal(t, errors.New("failed to verify signature: the validation failed for subject [123456789]"), err) - assert.Equal(t, true, verifiedIDToken == nil) - - // When there is no id token in the oauth token - verifiedIDToken, err = provider.findVerifiedIDToken(context.Background(), newOauth2Token()) - assert.Equal(t, nil, err) - assert.Equal(t, true, verifiedIDToken == nil) -} diff --git a/providers/provider_data.go b/providers/provider_data.go index 330df7ca..09eadd25 100644 --- a/providers/provider_data.go +++ b/providers/provider_data.go @@ -1,12 +1,18 @@ package providers import ( + "context" "errors" + "fmt" "io/ioutil" "net/url" + "reflect" + "strings" "github.com/coreos/go-oidc" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" + "golang.org/x/oauth2" ) // ProviderData contains information required to configure all implementations @@ -27,7 +33,12 @@ type ProviderData struct { ClientSecretFile string Scope string Prompt string - Verifier *oidc.IDTokenVerifier + + // Common OIDC options for any OIDC-based providers to consume + AllowUnverifiedEmail bool + EmailClaim string + GroupsClaim string + Verifier *oidc.IDTokenVerifier // Universal Group authorization data structure // any provider can set to consume @@ -94,3 +105,99 @@ func defaultURL(u *url.URL, d *url.URL) *url.URL { } return &url.URL{} } + +// **************************************************************************** +// These private OIDC helper methods are available to any providers that are +// OIDC compliant +// **************************************************************************** + +// OIDCClaims is a struct to unmarshal the OIDC claims from an ID Token payload +type OIDCClaims struct { + Subject string `json:"sub"` + Email string `json:"-"` + Groups []string `json:"-"` + Verified *bool `json:"email_verified"` + + raw map[string]interface{} +} + +func (p *ProviderData) verifyIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) { + rawIDToken := getIDToken(token) + if strings.TrimSpace(rawIDToken) != "" { + return p.Verifier.Verify(ctx, rawIDToken) + } + return nil, nil +} + +// buildSessionFromClaims uses IDToken claims to populate a fresh SessionState +// with non-Token related fields. +func (p *ProviderData) buildSessionFromClaims(idToken *oidc.IDToken) (*sessions.SessionState, error) { + ss := &sessions.SessionState{} + + if idToken == nil { + return ss, nil + } + + claims, err := p.getClaims(idToken) + if err != nil { + return nil, fmt.Errorf("couldn't extract claims from id_token (%v)", err) + } + + ss.User = claims.Subject + ss.Email = claims.Email + ss.Groups = claims.Groups + + // TODO (@NickMeves) Deprecate for dynamic claim to session mapping + if pref, ok := claims.raw["preferred_username"].(string); ok { + ss.PreferredUsername = pref + } + + // `email_verified` must be present and explicitly set to `false` to be + // considered unverified. + verifyEmail := (p.EmailClaim == emailClaim) && !p.AllowUnverifiedEmail + if verifyEmail && claims.Verified != nil && !*claims.Verified { + return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) + } + + return ss, nil +} + +// getClaims extracts IDToken claims into an OIDCClaims +func (p *ProviderData) getClaims(idToken *oidc.IDToken) (*OIDCClaims, error) { + claims := &OIDCClaims{} + + // Extract default claims. + if err := idToken.Claims(&claims); err != nil { + return nil, fmt.Errorf("failed to parse default id_token claims: %v", err) + } + // Extract custom claims. + if err := idToken.Claims(&claims.raw); err != nil { + return nil, fmt.Errorf("failed to parse all id_token claims: %v", err) + } + + email := claims.raw[p.EmailClaim] + if email != nil { + claims.Email = fmt.Sprint(email) + } + claims.Groups = p.extractGroups(claims.raw) + + return claims, nil +} + +// extractGroups extracts groups from a claim to a list in a type safe manner +func (p *ProviderData) extractGroups(claims map[string]interface{}) []string { + groups := []string{} + rawGroups, ok := claims[p.GroupsClaim].([]interface{}) + if rawGroups != nil && ok { + for _, rawGroup := range rawGroups { + formattedGroup, err := formatGroup(rawGroup) + if err != nil { + logger.Errorf("Warning: unable to format group of type %s with error %s", + reflect.TypeOf(rawGroup), err) + continue + } + groups = append(groups, formattedGroup) + } + } + return groups +} diff --git a/providers/provider_data_test.go b/providers/provider_data_test.go new file mode 100644 index 00000000..4aed73eb --- /dev/null +++ b/providers/provider_data_test.go @@ -0,0 +1,414 @@ +package providers + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "strings" + "testing" + "time" + + "github.com/coreos/go-oidc" + "github.com/dgrijalva/jwt-go" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" + . "github.com/onsi/gomega" + "golang.org/x/oauth2" +) + +const ( + idToken = "eyJfoobar123.eyJbaz987.IDToken" + accessToken = "eyJfoobar123.eyJbaz987.AccessToken" + refreshToken = "eyJfoobar123.eyJbaz987.RefreshToken" + + oidcIssuer = "https://issuer.example.com" + oidcClientID = "https://test.myapp.com" + oidcSecret = "SuperSecret123456789" + + failureTokenID = "this-id-fails-verification" +) + +var ( + verified = true + unverified = false + + standardClaims = jwt.StandardClaims{ + Audience: oidcClientID, + ExpiresAt: time.Now().Add(time.Duration(5) * time.Minute).Unix(), + Id: "id-some-id", + IssuedAt: time.Now().Unix(), + Issuer: oidcIssuer, + NotBefore: 0, + Subject: "123456789", + } + + defaultIDToken = idTokenClaims{ + Name: "Jane Dobbs", + Email: "janed@me.com", + Phone: "+4798765432", + Picture: "http://mugbook.com/janed/me.jpg", + Groups: []string{"test:a", "test:b"}, + Roles: []string{"test:c", "test:d"}, + Verified: &verified, + StandardClaims: standardClaims, + } + + complexGroupsIDToken = idTokenClaims{ + Name: "Complex Claim", + Email: "complex@claims.com", + Phone: "+5439871234", + Picture: "http://mugbook.com/complex/claims.jpg", + Groups: []map[string]interface{}{ + { + "groupId": "Admin Group Id", + "roles": []string{"Admin"}, + }, + }, + Roles: []string{"test:simple", "test:roles"}, + Verified: &verified, + StandardClaims: standardClaims, + } + + unverifiedIDToken = idTokenClaims{ + Name: "Mystery Man", + Email: "unverified@email.com", + Phone: "+4025205729", + Picture: "http://mugbook.com/unverified/email.jpg", + Groups: []string{"test:a", "test:b"}, + Roles: []string{"test:c", "test:d"}, + Verified: &unverified, + StandardClaims: standardClaims, + } + + minimalIDToken = idTokenClaims{ + StandardClaims: standardClaims, + } +) + +type idTokenClaims struct { + Name string `json:"preferred_username,omitempty"` + Email string `json:"email,omitempty"` + Phone string `json:"phone_number,omitempty"` + Picture string `json:"picture,omitempty"` + Groups interface{} `json:"groups,omitempty"` + Roles interface{} `json:"roles,omitempty"` + Verified *bool `json:"email_verified,omitempty"` + jwt.StandardClaims +} + +type mockJWKS struct{} + +func (mockJWKS) VerifySignature(_ context.Context, jwt string) ([]byte, error) { + decoded, err := base64.RawURLEncoding.DecodeString(strings.Split(jwt, ".")[1]) + if err != nil { + return nil, err + } + + tokenClaims := &idTokenClaims{} + err = json.Unmarshal(decoded, tokenClaims) + if err != nil || tokenClaims.Id == failureTokenID { + return nil, fmt.Errorf("the validation failed for subject [%v]", tokenClaims.Subject) + } + + return decoded, nil +} + +func newSignedTestIDToken(tokenClaims idTokenClaims) (string, error) { + key, _ := rsa.GenerateKey(rand.Reader, 2048) + standardClaims := jwt.NewWithClaims(jwt.SigningMethodRS256, tokenClaims) + return standardClaims.SignedString(key) +} + +func newTestOauth2Token() *oauth2.Token { + return &oauth2.Token{ + AccessToken: accessToken, + TokenType: "Bearer", + RefreshToken: refreshToken, + Expiry: time.Time{}.Add(time.Duration(5) * time.Second), + } +} + +func TestProviderData_verifyIDToken(t *testing.T) { + failureIDToken := defaultIDToken + failureIDToken.Id = failureTokenID + + testCases := map[string]struct { + IDToken *idTokenClaims + ExpectIDToken bool + ExpectedError error + }{ + "Valid ID Token": { + IDToken: &defaultIDToken, + ExpectIDToken: true, + ExpectedError: nil, + }, + "Invalid ID Token": { + IDToken: &failureIDToken, + ExpectIDToken: false, + ExpectedError: errors.New("failed to verify signature: the validation failed for subject [123456789]"), + }, + "Missing ID Token": { + IDToken: nil, + ExpectIDToken: false, + ExpectedError: nil, + }, + } + + for testName, tc := range testCases { + t.Run(testName, func(t *testing.T) { + g := NewWithT(t) + + token := newTestOauth2Token() + if tc.IDToken != nil { + idToken, err := newSignedTestIDToken(*tc.IDToken) + g.Expect(err).ToNot(HaveOccurred()) + token = token.WithExtra(map[string]interface{}{ + "id_token": idToken, + }) + } + + provider := &ProviderData{ + Verifier: oidc.NewVerifier( + oidcIssuer, + mockJWKS{}, + &oidc.Config{ClientID: oidcClientID}, + ), + } + verified, err := provider.verifyIDToken(context.Background(), token) + if err != nil { + g.Expect(err).To(Equal(tc.ExpectedError)) + } + + if tc.ExpectIDToken { + g.Expect(verified).ToNot(BeNil()) + g.Expect(*verified).To(BeAssignableToTypeOf(oidc.IDToken{})) + } else { + g.Expect(verified).To(BeNil()) + } + }) + } +} + +func TestProviderData_buildSessionFromClaims(t *testing.T) { + testCases := map[string]struct { + IDToken idTokenClaims + AllowUnverified bool + EmailClaim string + GroupsClaim string + ExpectedError error + ExpectedSession *sessions.SessionState + }{ + "Standard": { + IDToken: defaultIDToken, + AllowUnverified: false, + EmailClaim: "email", + GroupsClaim: "groups", + ExpectedSession: &sessions.SessionState{ + User: "123456789", + Email: "janed@me.com", + Groups: []string{"test:a", "test:b"}, + PreferredUsername: "Jane Dobbs", + }, + }, + "Unverified Denied": { + IDToken: unverifiedIDToken, + AllowUnverified: false, + EmailClaim: "email", + GroupsClaim: "groups", + ExpectedError: errors.New("email in id_token (unverified@email.com) isn't verified"), + }, + "Unverified Allowed": { + IDToken: unverifiedIDToken, + AllowUnverified: true, + EmailClaim: "email", + GroupsClaim: "groups", + ExpectedSession: &sessions.SessionState{ + User: "123456789", + Email: "unverified@email.com", + Groups: []string{"test:a", "test:b"}, + PreferredUsername: "Mystery Man", + }, + }, + "Complex Groups": { + IDToken: complexGroupsIDToken, + AllowUnverified: true, + EmailClaim: "email", + GroupsClaim: "groups", + ExpectedSession: &sessions.SessionState{ + User: "123456789", + Email: "complex@claims.com", + Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"}, + PreferredUsername: "Complex Claim", + }, + }, + "Email Claim Switched": { + IDToken: unverifiedIDToken, + AllowUnverified: true, + EmailClaim: "phone_number", + GroupsClaim: "groups", + ExpectedSession: &sessions.SessionState{ + User: "123456789", + Email: "+4025205729", + Groups: []string{"test:a", "test:b"}, + PreferredUsername: "Mystery Man", + }, + }, + "Email Claim Switched to Non String": { + IDToken: unverifiedIDToken, + AllowUnverified: true, + EmailClaim: "roles", + GroupsClaim: "groups", + ExpectedSession: &sessions.SessionState{ + User: "123456789", + Email: "[test:c test:d]", + Groups: []string{"test:a", "test:b"}, + PreferredUsername: "Mystery Man", + }, + }, + "Email Claim Non Existent": { + IDToken: unverifiedIDToken, + AllowUnverified: true, + EmailClaim: "aksjdfhjksadh", + GroupsClaim: "groups", + ExpectedSession: &sessions.SessionState{ + User: "123456789", + Email: "", + Groups: []string{"test:a", "test:b"}, + PreferredUsername: "Mystery Man", + }, + }, + "Groups Claim Switched": { + IDToken: defaultIDToken, + AllowUnverified: false, + EmailClaim: "email", + GroupsClaim: "roles", + ExpectedSession: &sessions.SessionState{ + User: "123456789", + Email: "janed@me.com", + Groups: []string{"test:c", "test:d"}, + PreferredUsername: "Jane Dobbs", + }, + }, + "Groups Claim Non Existent": { + IDToken: defaultIDToken, + AllowUnverified: false, + EmailClaim: "email", + GroupsClaim: "alskdjfsalkdjf", + ExpectedSession: &sessions.SessionState{ + User: "123456789", + Email: "janed@me.com", + Groups: []string{}, + PreferredUsername: "Jane Dobbs", + }, + }, + } + for testName, tc := range testCases { + t.Run(testName, func(t *testing.T) { + g := NewWithT(t) + + provider := &ProviderData{ + Verifier: oidc.NewVerifier( + oidcIssuer, + mockJWKS{}, + &oidc.Config{ClientID: oidcClientID}, + ), + } + provider.AllowUnverifiedEmail = tc.AllowUnverified + provider.EmailClaim = tc.EmailClaim + provider.GroupsClaim = tc.GroupsClaim + + rawIDToken, err := newSignedTestIDToken(tc.IDToken) + g.Expect(err).ToNot(HaveOccurred()) + + idToken, err := provider.Verifier.Verify(context.Background(), rawIDToken) + g.Expect(err).ToNot(HaveOccurred()) + + ss, err := provider.buildSessionFromClaims(idToken) + if err != nil { + g.Expect(err).To(Equal(tc.ExpectedError)) + } + if ss != nil { + g.Expect(ss).To(Equal(tc.ExpectedSession)) + } + }) + } +} + +func TestProviderData_extractGroups(t *testing.T) { + testCases := map[string]struct { + Claims map[string]interface{} + GroupsClaim string + ExpectedGroups []string + }{ + "Standard String Groups": { + Claims: map[string]interface{}{ + "email": "this@does.not.matter.com", + "groups": []interface{}{"three", "string", "groups"}, + }, + GroupsClaim: "groups", + ExpectedGroups: []string{"three", "string", "groups"}, + }, + "Different Claim Name": { + Claims: map[string]interface{}{ + "email": "this@does.not.matter.com", + "roles": []interface{}{"three", "string", "roles"}, + }, + GroupsClaim: "roles", + ExpectedGroups: []string{"three", "string", "roles"}, + }, + "Numeric Groups": { + Claims: map[string]interface{}{ + "email": "this@does.not.matter.com", + "groups": []interface{}{1, 2, 3}, + }, + GroupsClaim: "groups", + ExpectedGroups: []string{"1", "2", "3"}, + }, + "Complex Groups": { + Claims: map[string]interface{}{ + "email": "this@does.not.matter.com", + "groups": []interface{}{ + map[string]interface{}{ + "groupId": "Admin Group Id", + "roles": []string{"Admin"}, + }, + 12345, + "Just::A::String", + }, + }, + GroupsClaim: "groups", + ExpectedGroups: []string{ + "{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}", + "12345", + "Just::A::String", + }, + }, + "Missing Groups": { + Claims: map[string]interface{}{ + "email": "this@does.not.matter.com", + }, + GroupsClaim: "groups", + ExpectedGroups: []string{}, + }, + } + for testName, tc := range testCases { + t.Run(testName, func(t *testing.T) { + g := NewWithT(t) + + provider := &ProviderData{ + Verifier: oidc.NewVerifier( + oidcIssuer, + mockJWKS{}, + &oidc.Config{ClientID: oidcClientID}, + ), + } + provider.GroupsClaim = tc.GroupsClaim + + groups := provider.extractGroups(tc.Claims) + g.Expect(groups).To(Equal(tc.ExpectedGroups)) + }) + } +} diff --git a/providers/provider_default.go b/providers/provider_default.go index 012a538c..69dddb06 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -13,6 +13,8 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" ) +const emailClaim = "email" + var ( // ErrNotImplemented is returned when a provider did not override a default // implementation method that doesn't have sensible defaults From eb56f24d6dba6170fde7ec74731e105ac57d947f Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Sat, 28 Nov 2020 12:33:05 -0800 Subject: [PATCH 03/28] Deprecate UserIDClaim in config and docs --- CHANGELOG.md | 2 ++ docs/docs/configuration/overview.md | 4 ++-- pkg/apis/options/options.go | 11 +++++++---- pkg/validation/options.go | 12 ++++++++---- providers/oidc.go | 2 +- providers/provider_data.go | 7 ++++++- providers/provider_default.go | 2 -- 7 files changed, 26 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2f6065b4..74c272eb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ ## Important Notes +- [#936](https://github.com/oauth2-proxy/oauth2-proxy/pull/936) `--user-id-claim` option is deprecated and replaced by `--oidc-email-claim` - [#630](https://github.com/oauth2-proxy/oauth2-proxy/pull/630) Gitlab projects needs a Gitlab application with the extra `read_api` enabled - [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Existing sessions from v6.0.0 or earlier are no longer valid. They will trigger a reauthentication. - [#826](https://github.com/oauth2-proxy/oauth2-proxy/pull/826) `skip-auth-strip-headers` now applies to all requests, not just those where authentication would be skipped. @@ -47,6 +48,7 @@ - [#630](https://github.com/oauth2-proxy/oauth2-proxy/pull/630) Add support for Gitlab project based authentication (@factorysh) - [#907](https://github.com/oauth2-proxy/oauth2-proxy/pull/907) Introduce alpha configuration option to enable testing of structured configuration (@JoelSpeed) - [#938](https://github.com/oauth2-proxy/oauth2-proxy/pull/938) Cleanup missed provider renaming refactor methods (@NickMeves) +- [#936](https://github.com/oauth2-proxy/oauth2-proxy/pull/936) Refactor OIDC Provider and support groups from Profile URL (@NickMeves) - [#925](https://github.com/oauth2-proxy/oauth2-proxy/pull/925) Fix basic auth legacy header conversion (@JoelSpeed) - [#916](https://github.com/oauth2-proxy/oauth2-proxy/pull/916) Add AlphaOptions struct to prepare for alpha config loading (@JoelSpeed) - [#923](https://github.com/oauth2-proxy/oauth2-proxy/pull/923) Support TLS 1.3 (@aajisaka) diff --git a/docs/docs/configuration/overview.md b/docs/docs/configuration/overview.md index 46e72d67..a91c6264 100644 --- a/docs/docs/configuration/overview.md +++ b/docs/docs/configuration/overview.md @@ -74,7 +74,8 @@ An example [oauth2-proxy.cfg](https://github.com/oauth2-proxy/oauth2-proxy/blob/ | `--insecure-oidc-skip-issuer-verification` | bool | allow the OIDC issuer URL to differ from the expected (currently required for Azure multi-tenant compatibility) | false | | `--oidc-issuer-url` | string | the OpenID Connect issuer URL, e.g. `"https://accounts.google.com"` | | | `--oidc-jwks-url` | string | OIDC JWKS URI for token verification; required if OIDC discovery is disabled | | -| `--oidc-groups-claim` | string | which claim contains the user groups | `"groups"` | +| `--oidc-email-claim` | string | which OIDC claim contains the user's email | `"email"` | +| `--oidc-groups-claim` | string | which OIDC claim contains the user groups | `"groups"` | | `--pass-access-token` | bool | pass OAuth access_token to upstream via X-Forwarded-Access-Token header. When used with `--set-xauthrequest` this adds the X-Auth-Request-Access-Token header to the response | false | | `--pass-authorization-header` | bool | pass OIDC IDToken to upstream via Authorization Bearer header | false | | `--pass-basic-auth` | bool | pass HTTP Basic Auth, X-Forwarded-User, X-Forwarded-Email and X-Forwarded-Preferred-Username information to upstream | true | @@ -128,7 +129,6 @@ An example [oauth2-proxy.cfg](https://github.com/oauth2-proxy/oauth2-proxy/blob/ | `--tls-cert-file` | string | path to certificate file | | | `--tls-key-file` | string | path to private key file | | | `--upstream` | string \| list | the http url(s) of the upstream endpoint, file:// paths for static files or `static://` for static response. Routing is based on the path | | -| `--user-id-claim` | string | which claim contains the user ID | \["email"\] | | `--allowed-group` | string \| list | restrict logins to members of this group (may be given multiple times) | | | `--validate-url` | string | Access token validation endpoint | | | `--version` | n/a | print version string | | diff --git a/pkg/apis/options/options.go b/pkg/apis/options/options.go index cf4e2414..46cdcedb 100644 --- a/pkg/apis/options/options.go +++ b/pkg/apis/options/options.go @@ -87,6 +87,7 @@ type Options struct { InsecureOIDCSkipIssuerVerification bool `flag:"insecure-oidc-skip-issuer-verification" cfg:"insecure_oidc_skip_issuer_verification"` SkipOIDCDiscovery bool `flag:"skip-oidc-discovery" cfg:"skip_oidc_discovery"` OIDCJwksURL string `flag:"oidc-jwks-url" cfg:"oidc_jwks_url"` + OIDCEmailClaim string `flag:"oidc-email-claim" cfg:"oidc_email_claim"` OIDCGroupsClaim string `flag:"oidc-groups-claim" cfg:"oidc_groups_claim"` LoginURL string `flag:"login-url" cfg:"login_url"` RedeemURL string `flag:"redeem-url" cfg:"redeem_url"` @@ -148,11 +149,12 @@ func NewOptions() *Options { SkipAuthPreflight: false, Prompt: "", // Change to "login" when ApprovalPrompt officially deprecated ApprovalPrompt: "force", - UserIDClaim: "email", InsecureOIDCAllowUnverifiedEmail: false, SkipOIDCDiscovery: false, Logging: loggingDefaults(), - OIDCGroupsClaim: "groups", + UserIDClaim: providers.OIDCEmailClaim, // Deprecated: Use OIDCEmailClaim + OIDCEmailClaim: providers.OIDCEmailClaim, + OIDCGroupsClaim: providers.OIDCGroupsClaim, } } @@ -226,7 +228,8 @@ func NewFlagSet() *pflag.FlagSet { flagSet.Bool("insecure-oidc-skip-issuer-verification", false, "Do not verify if issuer matches OIDC discovery URL") flagSet.Bool("skip-oidc-discovery", false, "Skip OIDC discovery and use manually supplied Endpoints") flagSet.String("oidc-jwks-url", "", "OpenID Connect JWKS URL (ie: https://www.googleapis.com/oauth2/v3/certs)") - flagSet.String("oidc-groups-claim", "groups", "which claim contains the user groups") + flagSet.String("oidc-groups-claim", providers.OIDCGroupsClaim, "which OIDC claim contains the user groups") + flagSet.String("oidc-email-claim", providers.OIDCEmailClaim, "which OIDC claim contains the user's email") flagSet.String("login-url", "", "Authentication endpoint") flagSet.String("redeem-url", "", "Token redemption endpoint") flagSet.String("profile-url", "", "Profile access endpoint") @@ -243,7 +246,7 @@ func NewFlagSet() *pflag.FlagSet { flagSet.String("pubjwk-url", "", "JWK pubkey access endpoint: required by login.gov") flagSet.Bool("gcp-healthchecks", false, "Enable GCP/GKE healthcheck endpoints") - flagSet.String("user-id-claim", "email", "which claim contains the user ID") + flagSet.String("user-id-claim", providers.OIDCEmailClaim, "(DEPRECATED for `oidc-email-claim`) which claim contains the user ID") flagSet.StringSlice("allowed-group", []string{}, "restrict logins to members of this group (may be given multiple times)") flagSet.AddFlagSet(cookieFlagSet()) diff --git a/pkg/validation/options.go b/pkg/validation/options.go index 652ada9e..d5e58312 100644 --- a/pkg/validation/options.go +++ b/pkg/validation/options.go @@ -235,10 +235,17 @@ func parseProviderInfo(o *options.Options, msgs []string) []string { // Make the OIDC options available to all providers that support it p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail - p.EmailClaim = o.UserIDClaim + p.EmailClaim = o.OIDCEmailClaim p.GroupsClaim = o.OIDCGroupsClaim p.Verifier = o.GetOIDCVerifier() + // TODO (@NickMeves) - Remove This + // Backwards Compatibility for Deprecated UserIDClaim option + if o.OIDCEmailClaim == providers.OIDCEmailClaim && + o.UserIDClaim != providers.OIDCEmailClaim { + p.EmailClaim = o.UserIDClaim + } + p.SetAllowedGroups(o.AllowedGroups) provider := providers.New(o.ProviderType, p) @@ -276,9 +283,6 @@ func parseProviderInfo(o *options.Options, msgs []string) []string { p.SetTeam(o.BitbucketTeam) p.SetRepository(o.BitbucketRepository) case *providers.OIDCProvider: - p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail - p.EmailClaim = o.UserIDClaim - p.GroupsClaim = o.OIDCGroupsClaim if p.Verifier == nil { msgs = append(msgs, "oidc provider requires an oidc issuer URL") } diff --git a/providers/oidc.go b/providers/oidc.go index f90348d6..d7d34700 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -42,7 +42,7 @@ func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (*s } token, err := c.Exchange(ctx, code) if err != nil { - return nil, fmt.Errorf("token exchange failure: %v", err) + return nil, fmt.Errorf("token exchange failed: %v", err) } return p.createSession(ctx, token, false) diff --git a/providers/provider_data.go b/providers/provider_data.go index 09eadd25..ae434515 100644 --- a/providers/provider_data.go +++ b/providers/provider_data.go @@ -15,6 +15,11 @@ import ( "golang.org/x/oauth2" ) +const ( + OIDCEmailClaim = "email" + OIDCGroupsClaim = "groups" +) + // ProviderData contains information required to configure all implementations // of OAuth2 providers type ProviderData struct { @@ -154,7 +159,7 @@ func (p *ProviderData) buildSessionFromClaims(idToken *oidc.IDToken) (*sessions. // `email_verified` must be present and explicitly set to `false` to be // considered unverified. - verifyEmail := (p.EmailClaim == emailClaim) && !p.AllowUnverifiedEmail + verifyEmail := (p.EmailClaim == OIDCEmailClaim) && !p.AllowUnverifiedEmail if verifyEmail && claims.Verified != nil && !*claims.Verified { return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) } diff --git a/providers/provider_default.go b/providers/provider_default.go index 69dddb06..012a538c 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -13,8 +13,6 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" ) -const emailClaim = "email" - var ( // ErrNotImplemented is returned when a provider did not override a default // implementation method that doesn't have sensible defaults From ea5b8cc21fb6c8feac6226b1131c45e9a173df2f Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Sun, 29 Nov 2020 14:58:01 -0800 Subject: [PATCH 04/28] Support non-list and complex groups --- CHANGELOG.md | 1 + providers/oidc.go | 18 +++--- providers/oidc_test.go | 107 ++++++++++++++++++++++++++++---- providers/provider_data.go | 36 +++++++---- providers/provider_data_test.go | 20 ++++-- providers/util.go | 20 ++++++ 6 files changed, 166 insertions(+), 36 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 74c272eb..84f54dec 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -48,6 +48,7 @@ - [#630](https://github.com/oauth2-proxy/oauth2-proxy/pull/630) Add support for Gitlab project based authentication (@factorysh) - [#907](https://github.com/oauth2-proxy/oauth2-proxy/pull/907) Introduce alpha configuration option to enable testing of structured configuration (@JoelSpeed) - [#938](https://github.com/oauth2-proxy/oauth2-proxy/pull/938) Cleanup missed provider renaming refactor methods (@NickMeves) +- [#816](https://github.com/oauth2-proxy/oauth2-proxy/pull/816) (via [#936](https://github.com/oauth2-proxy/oauth2-proxy/pull/936)) Support non-list group claims (@loafoe) - [#936](https://github.com/oauth2-proxy/oauth2-proxy/pull/936) Refactor OIDC Provider and support groups from Profile URL (@NickMeves) - [#925](https://github.com/oauth2-proxy/oauth2-proxy/pull/925) Fix basic auth legacy header conversion (@JoelSpeed) - [#916](https://github.com/oauth2-proxy/oauth2-proxy/pull/916) Add AlphaOptions struct to prepare for alpha config loading (@JoelSpeed) diff --git a/providers/oidc.go b/providers/oidc.go index d7d34700..98cefb4b 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "reflect" "time" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" @@ -59,7 +60,7 @@ func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionSta } // Try to get missing emails or groups from a profileURL - if s.Email == "" || len(s.Groups) == 0 { + if s.Email == "" || s.Groups == nil { err := p.callProfileURL(ctx, s) if err != nil { logger.Errorf("Warning: Profile URL request failed: %v", err) @@ -90,16 +91,15 @@ func (p *OIDCProvider) callProfileURL(ctx context.Context, s *sessions.SessionSt s.Email = email } - // Handle array & singleton groups cases if len(s.Groups) == 0 { - groups, err := respJSON.Get(p.GroupsClaim).StringArray() - if err == nil { - s.Groups = groups - } else { - group, err := respJSON.Get(p.GroupsClaim).String() - if err == nil { - s.Groups = []string{group} + for _, group := range coerceArray(respJSON, p.GroupsClaim) { + formatted, err := formatGroup(group) + if err != nil { + logger.Errorf("Warning: unable to format group of type %s with error %s", + reflect.TypeOf(group), err) + continue } + s.Groups = append(s.Groups, formatted) } } diff --git a/providers/oidc_test.go b/providers/oidc_test.go index 2651b4ea..7ac98634 100644 --- a/providers/oidc_test.go +++ b/providers/oidc_test.go @@ -68,7 +68,7 @@ func newOIDCServer(body []byte) (*url.URL, *httptest.Server) { return u, s } -func newTestSetup(body []byte) (*httptest.Server, *OIDCProvider) { +func newTestOIDCSetup(body []byte) (*httptest.Server, *OIDCProvider) { redeemURL, server := newOIDCServer(body) provider := newOIDCProvider(redeemURL) return server, provider @@ -85,7 +85,7 @@ func TestOIDCProviderRedeem(t *testing.T) { IDToken: idToken, }) - server, provider := newTestSetup(body) + server, provider := newTestOIDCSetup(body) defer server.Close() session, err := provider.Redeem(context.Background(), provider.RedeemURL.String(), "code1234") @@ -108,7 +108,7 @@ func TestOIDCProviderRedeem_custom_userid(t *testing.T) { IDToken: idToken, }) - server, provider := newTestSetup(body) + server, provider := newTestOIDCSetup(body) provider.EmailClaim = "phone_number" defer server.Close() @@ -247,7 +247,7 @@ func TestOIDCProvider_EnrichSession(t *testing.T) { ExistingSession: &sessions.SessionState{ User: "already", Email: "already@populated.com", - Groups: []string{}, + Groups: nil, IDToken: idToken, AccessToken: accessToken, RefreshToken: refreshToken, @@ -268,6 +268,89 @@ func TestOIDCProvider_EnrichSession(t *testing.T) { RefreshToken: refreshToken, }, }, + "Missing Groups with Complex Groups in Profile URL": { + ExistingSession: &sessions.SessionState{ + User: "already", + Email: "already@populated.com", + Groups: nil, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + EmailClaim: "email", + GroupsClaim: "groups", + ProfileJSON: map[string]interface{}{ + "email": "new@thing.com", + "groups": []map[string]interface{}{ + { + "groupId": "Admin Group Id", + "roles": []string{"Admin"}, + }, + }, + }, + ExpectedError: nil, + ExpectedSession: &sessions.SessionState{ + User: "already", + Email: "already@populated.com", + Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + }, + "Missing Groups with Singleton Complex Group in Profile URL": { + ExistingSession: &sessions.SessionState{ + User: "already", + Email: "already@populated.com", + Groups: nil, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + EmailClaim: "email", + GroupsClaim: "groups", + ProfileJSON: map[string]interface{}{ + "email": "new@thing.com", + "groups": map[string]interface{}{ + "groupId": "Admin Group Id", + "roles": []string{"Admin"}, + }, + }, + ExpectedError: nil, + ExpectedSession: &sessions.SessionState{ + User: "already", + Email: "already@populated.com", + Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + }, + "Empty Groups Claims": { + ExistingSession: &sessions.SessionState{ + User: "already", + Email: "already@populated.com", + Groups: []string{}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + EmailClaim: "email", + GroupsClaim: "groups", + ProfileJSON: map[string]interface{}{ + "email": "new@thing.com", + "groups": []string{"new", "thing"}, + }, + ExpectedError: nil, + ExpectedSession: &sessions.SessionState{ + User: "already", + Email: "already@populated.com", + Groups: []string{}, + IDToken: idToken, + AccessToken: accessToken, + RefreshToken: refreshToken, + }, + }, "Missing Groups with Custom Claim": { ExistingSession: &sessions.SessionState{ User: "already", @@ -297,7 +380,7 @@ func TestOIDCProvider_EnrichSession(t *testing.T) { ExistingSession: &sessions.SessionState{ User: "already", Email: "already@populated.com", - Groups: []string{}, + Groups: nil, IDToken: idToken, AccessToken: accessToken, RefreshToken: refreshToken, @@ -346,7 +429,7 @@ func TestOIDCProvider_EnrichSession(t *testing.T) { jsonResp, err := json.Marshal(tc.ProfileJSON) assert.NoError(t, err) - server, provider := newTestSetup(jsonResp) + server, provider := newTestOIDCSetup(jsonResp) provider.ProfileURL, err = url.Parse(server.URL) assert.NoError(t, err) @@ -371,7 +454,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) { RefreshToken: refreshToken, }) - server, provider := newTestSetup(body) + server, provider := newTestOIDCSetup(body) defer server.Close() existingSession := &sessions.SessionState{ @@ -405,7 +488,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) { IDToken: idToken, }) - server, provider := newTestSetup(body) + server, provider := newTestOIDCSetup(body) defer server.Close() existingSession := &sessions.SessionState{ @@ -433,7 +516,7 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) { GroupsClaim string ExpectedUser string ExpectedEmail string - ExpectedGroups interface{} + ExpectedGroups []string }{ "Default IDToken": { IDToken: defaultIDToken, @@ -447,7 +530,7 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) { GroupsClaim: "groups", ExpectedUser: "123456789", ExpectedEmail: "123456789", - ExpectedGroups: []string{}, + ExpectedGroups: nil, }, "Custom Groups Claim": { IDToken: defaultIDToken, @@ -466,7 +549,7 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) { } for testName, tc := range testCases { t.Run(testName, func(t *testing.T) { - server, provider := newTestSetup([]byte(`{}`)) + server, provider := newTestOIDCSetup([]byte(`{}`)) provider.GroupsClaim = tc.GroupsClaim defer server.Close() @@ -478,9 +561,9 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) { assert.Equal(t, tc.ExpectedUser, ss.User) assert.Equal(t, tc.ExpectedEmail, ss.Email) + assert.Equal(t, tc.ExpectedGroups, ss.Groups) assert.Equal(t, rawIDToken, ss.IDToken) assert.Equal(t, rawIDToken, ss.AccessToken) - assert.Equal(t, tc.ExpectedGroups, ss.Groups) assert.Equal(t, "", ss.RefreshToken) }) } diff --git a/providers/provider_data.go b/providers/provider_data.go index ae434515..098e6192 100644 --- a/providers/provider_data.go +++ b/providers/provider_data.go @@ -189,20 +189,34 @@ func (p *ProviderData) getClaims(idToken *oidc.IDToken) (*OIDCClaims, error) { return claims, nil } -// extractGroups extracts groups from a claim to a list in a type safe manner +// extractGroups extracts groups from a claim to a list in a type safe manner. +// If the claim isn't present, `nil` is returned. If the groups claim is +// present but empty, `[]string{}` is returned. func (p *ProviderData) extractGroups(claims map[string]interface{}) []string { + rawClaim, ok := claims[p.GroupsClaim] + if !ok { + return nil + } + + // Handle traditional list-based groups as well as non-standard singleton + // based groups. Both variants support complex objects if needed. + var claimGroups []interface{} + switch raw := rawClaim.(type) { + case []interface{}: + claimGroups = raw + case interface{}: + claimGroups = []interface{}{raw} + } + groups := []string{} - rawGroups, ok := claims[p.GroupsClaim].([]interface{}) - if rawGroups != nil && ok { - for _, rawGroup := range rawGroups { - formattedGroup, err := formatGroup(rawGroup) - if err != nil { - logger.Errorf("Warning: unable to format group of type %s with error %s", - reflect.TypeOf(rawGroup), err) - continue - } - groups = append(groups, formattedGroup) + for _, rawGroup := range claimGroups { + formattedGroup, err := formatGroup(rawGroup) + if err != nil { + logger.Errorf("Warning: unable to format group of type %s with error %s", + reflect.TypeOf(rawGroup), err) + continue } + groups = append(groups, formattedGroup) } return groups } diff --git a/providers/provider_data_test.go b/providers/provider_data_test.go index 4aed73eb..f94c0db1 100644 --- a/providers/provider_data_test.go +++ b/providers/provider_data_test.go @@ -300,7 +300,7 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) { ExpectedSession: &sessions.SessionState{ User: "123456789", Email: "janed@me.com", - Groups: []string{}, + Groups: nil, PreferredUsername: "Jane Dobbs", }, }, @@ -386,12 +386,20 @@ func TestProviderData_extractGroups(t *testing.T) { "Just::A::String", }, }, - "Missing Groups": { + "Missing Groups Claim Returns Nil": { Claims: map[string]interface{}{ "email": "this@does.not.matter.com", }, GroupsClaim: "groups", - ExpectedGroups: []string{}, + ExpectedGroups: nil, + }, + "Non List Groups": { + Claims: map[string]interface{}{ + "email": "this@does.not.matter.com", + "groups": "singleton", + }, + GroupsClaim: "groups", + ExpectedGroups: []string{"singleton"}, }, } for testName, tc := range testCases { @@ -408,7 +416,11 @@ func TestProviderData_extractGroups(t *testing.T) { provider.GroupsClaim = tc.GroupsClaim groups := provider.extractGroups(tc.Claims) - g.Expect(groups).To(Equal(tc.ExpectedGroups)) + if tc.ExpectedGroups != nil { + g.Expect(groups).To(Equal(tc.ExpectedGroups)) + } else { + g.Expect(groups).To(BeNil()) + } }) } } diff --git a/providers/util.go b/providers/util.go index acf20902..055d29db 100644 --- a/providers/util.go +++ b/providers/util.go @@ -6,6 +6,7 @@ import ( "net/http" "net/url" + "github.com/bitly/go-simplejson" "golang.org/x/oauth2" ) @@ -59,6 +60,8 @@ func makeLoginURL(p *ProviderData, redirectURI, state string, extraParams url.Va return a } +// getIDToken extracts an IDToken stored in the `Extra` fields of an +// oauth2.Token func getIDToken(token *oauth2.Token) string { idToken, ok := token.Extra("id_token").(string) if !ok { @@ -67,6 +70,8 @@ func getIDToken(token *oauth2.Token) string { return idToken } +// formatGroup coerces an OIDC groups claim into a string +// If it is non-string, marshal it into JSON. func formatGroup(rawGroup interface{}) (string, error) { group, ok := rawGroup.(string) if !ok { @@ -78,3 +83,18 @@ func formatGroup(rawGroup interface{}) (string, error) { } return group, nil } + +// coerceArray extracts a field from simplejson.Json that might be a +// singleton or a list and coerces it into a list. +func coerceArray(sj *simplejson.Json, key string) []interface{} { + array, err := sj.Get(key).Array() + if err == nil { + return array + } + + single := sj.Get(key).Interface() + if single == nil { + return nil + } + return []interface{}{single} +} From 42f6cef7d6571aa0ea5fea109a77177a069a950e Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Tue, 1 Dec 2020 12:01:42 -0800 Subject: [PATCH 05/28] Improve OIDC error handling --- providers/oidc.go | 15 +++++++++------ providers/provider_data.go | 5 ++++- providers/provider_data_test.go | 19 +++++++++++++++---- providers/provider_default.go | 8 ++++++++ 4 files changed, 36 insertions(+), 11 deletions(-) diff --git a/providers/oidc.go b/providers/oidc.go index 98cefb4b..cdeee3b2 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -206,12 +206,15 @@ func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string) func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, refresh bool) (*sessions.SessionState, error) { idToken, err := p.verifyIDToken(ctx, token) if err != nil { - return nil, fmt.Errorf("could not verify id_token: %v", err) - } - - // IDToken is mandatory in Redeem but optional in Refresh - if idToken == nil && !refresh { - return nil, errors.New("token response did not contain an id_token") + switch err { + case ErrMissingIDToken: + // IDToken is mandatory in Redeem but optional in Refresh + if !refresh { + return nil, errors.New("token response did not contain an id_token") + } + default: + return nil, fmt.Errorf("could not verify id_token: %v", err) + } } ss, err := p.buildSessionFromClaims(idToken) diff --git a/providers/provider_data.go b/providers/provider_data.go index 098e6192..d8c9312b 100644 --- a/providers/provider_data.go +++ b/providers/provider_data.go @@ -129,9 +129,12 @@ type OIDCClaims struct { func (p *ProviderData) verifyIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) { rawIDToken := getIDToken(token) if strings.TrimSpace(rawIDToken) != "" { + if p.Verifier == nil { + return nil, ErrMissingOIDCVerifier + } return p.Verifier.Verify(ctx, rawIDToken) } - return nil, nil + return nil, ErrMissingIDToken } // buildSessionFromClaims uses IDToken claims to populate a fresh SessionState diff --git a/providers/provider_data_test.go b/providers/provider_data_test.go index f94c0db1..80f6ecab 100644 --- a/providers/provider_data_test.go +++ b/providers/provider_data_test.go @@ -137,23 +137,33 @@ func TestProviderData_verifyIDToken(t *testing.T) { testCases := map[string]struct { IDToken *idTokenClaims + Verifier bool ExpectIDToken bool ExpectedError error }{ "Valid ID Token": { IDToken: &defaultIDToken, + Verifier: true, ExpectIDToken: true, ExpectedError: nil, }, "Invalid ID Token": { IDToken: &failureIDToken, + Verifier: true, ExpectIDToken: false, ExpectedError: errors.New("failed to verify signature: the validation failed for subject [123456789]"), }, "Missing ID Token": { IDToken: nil, + Verifier: true, ExpectIDToken: false, - ExpectedError: nil, + ExpectedError: ErrMissingIDToken, + }, + "OIDC Verifier not Configured": { + IDToken: &defaultIDToken, + Verifier: false, + ExpectIDToken: false, + ExpectedError: ErrMissingOIDCVerifier, }, } @@ -170,12 +180,13 @@ func TestProviderData_verifyIDToken(t *testing.T) { }) } - provider := &ProviderData{ - Verifier: oidc.NewVerifier( + provider := &ProviderData{} + if tc.Verifier { + provider.Verifier = oidc.NewVerifier( oidcIssuer, mockJWKS{}, &oidc.Config{ClientID: oidcClientID}, - ), + ) } verified, err := provider.verifyIDToken(context.Background(), token) if err != nil { diff --git a/providers/provider_default.go b/providers/provider_default.go index 012a538c..d3c6d113 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -22,6 +22,14 @@ var ( // code ErrMissingCode = errors.New("missing code") + // ErrMissingIDToken is returned when an oidc.Token does not contain the + // extra `id_token` field for an IDToken. + ErrMissingIDToken = errors.New("missing id_token") + + // ErrMissingOIDCVerifier is returned when a provider didn't set `Verifier` + // but an attempt to call `Verifier.Verify` was about to be made. + ErrMissingOIDCVerifier = errors.New("oidc verifier is not configured") + _ Provider = (*ProviderData)(nil) ) From d2ffef2c7e0a622295c42cb2fdd70ef14593d03c Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Tue, 1 Dec 2020 17:50:27 -0800 Subject: [PATCH 06/28] Use global OIDC fields for Gitlab --- pkg/validation/options.go | 1 - providers/gitlab.go | 25 +++++++++++-------------- providers/oidc.go | 29 +++++++++++++++-------------- providers/provider_data.go | 12 ++++++------ providers/util.go | 16 ++++++++-------- 5 files changed, 40 insertions(+), 43 deletions(-) diff --git a/pkg/validation/options.go b/pkg/validation/options.go index d5e58312..4fc0b0a4 100644 --- a/pkg/validation/options.go +++ b/pkg/validation/options.go @@ -287,7 +287,6 @@ func parseProviderInfo(o *options.Options, msgs []string) []string { msgs = append(msgs, "oidc provider requires an oidc issuer URL") } case *providers.GitLabProvider: - p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail p.Groups = o.GitLabGroup err := p.AddProjects(o.GitlabProjects) if err != nil { diff --git a/providers/gitlab.go b/providers/gitlab.go index 246dd78c..c5922abb 100644 --- a/providers/gitlab.go +++ b/providers/gitlab.go @@ -20,8 +20,6 @@ type GitLabProvider struct { Groups []string Projects []*GitlabProject - - AllowUnverifiedEmail bool } // GitlabProject represents a Gitlab project constraint entity @@ -103,7 +101,7 @@ func (p *GitLabProvider) Redeem(ctx context.Context, redirectURL, code string) ( if err != nil { return nil, fmt.Errorf("token exchange: %v", err) } - s, err = p.createSessionState(ctx, token) + s, err = p.createSession(ctx, token) if err != nil { return nil, fmt.Errorf("unable to update session: %v", err) } @@ -162,7 +160,7 @@ func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.Ses if err != nil { return fmt.Errorf("failed to get token: %v", err) } - newSession, err := p.createSessionState(ctx, token) + newSession, err := p.createSession(ctx, token) if err != nil { return fmt.Errorf("unable to update session: %v", err) } @@ -255,22 +253,21 @@ func (p *GitLabProvider) AddProjects(projects []string) error { return nil } -func (p *GitLabProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*sessions.SessionState, error) { - rawIDToken, ok := token.Extra("id_token").(string) - if !ok { - return nil, fmt.Errorf("token response did not contain an id_token") - } - - // Parse and verify ID Token payload. - idToken, err := p.Verifier.Verify(ctx, rawIDToken) +func (p *GitLabProvider) createSession(ctx context.Context, token *oauth2.Token) (*sessions.SessionState, error) { + idToken, err := p.verifyIDToken(ctx, token) if err != nil { - return nil, fmt.Errorf("could not verify id_token: %v", err) + switch err { + case ErrMissingIDToken: + return nil, fmt.Errorf("token response did not contain an id_token") + default: + return nil, fmt.Errorf("could not verify id_token: %v", err) + } } created := time.Now() return &sessions.SessionState{ AccessToken: token.AccessToken, - IDToken: rawIDToken, + IDToken: getIDToken(token), RefreshToken: token.RefreshToken, CreatedAt: &created, ExpiresOn: &idToken.Expiry, diff --git a/providers/oidc.go b/providers/oidc.go index cdeee3b2..df133f4d 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -49,7 +49,7 @@ func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (*s return p.createSession(ctx, token, false) } -// EnrichSessionState is called after Redeem to allow providers to enrich session fields +// EnrichSession is called after Redeem to allow providers to enrich session fields // such as User, Email, Groups with provider specific API calls. func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error { if p.ProfileURL.String() == "" { @@ -61,7 +61,7 @@ func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionSta // Try to get missing emails or groups from a profileURL if s.Email == "" || s.Groups == nil { - err := p.callProfileURL(ctx, s) + err := p.enrichFromProfileURL(ctx, s) if err != nil { logger.Errorf("Warning: Profile URL request failed: %v", err) } @@ -74,9 +74,9 @@ func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionSta return nil } -// callProfileURL enriches a session's Email & Groups via the JSON response of +// enrichFromProfileURL enriches a session's Email & Groups via the JSON response of // an OIDC profile URL -func (p *OIDCProvider) callProfileURL(ctx context.Context, s *sessions.SessionState) error { +func (p *OIDCProvider) enrichFromProfileURL(ctx context.Context, s *sessions.SessionState) error { respJSON, err := requests.New(p.ProfileURL.String()). WithContext(ctx). WithHeaders(makeOIDCHeader(s.AccessToken)). @@ -91,22 +91,23 @@ func (p *OIDCProvider) callProfileURL(ctx context.Context, s *sessions.SessionSt s.Email = email } - if len(s.Groups) == 0 { - for _, group := range coerceArray(respJSON, p.GroupsClaim) { - formatted, err := formatGroup(group) - if err != nil { - logger.Errorf("Warning: unable to format group of type %s with error %s", - reflect.TypeOf(group), err) - continue - } - s.Groups = append(s.Groups, formatted) + if len(s.Groups) > 0 { + return nil + } + for _, group := range coerceArray(respJSON, p.GroupsClaim) { + formatted, err := formatGroup(group) + if err != nil { + logger.Errorf("Warning: unable to format group of type %s with error %s", + reflect.TypeOf(group), err) + continue } + s.Groups = append(s.Groups, formatted) } return nil } -// ValidateSessionState checks that the session's IDToken is still valid +// ValidateSession checks that the session's IDToken is still valid func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { _, err := p.Verifier.Verify(ctx, s.IDToken) return err == nil diff --git a/providers/provider_data.go b/providers/provider_data.go index d8c9312b..a9a41232 100644 --- a/providers/provider_data.go +++ b/providers/provider_data.go @@ -128,13 +128,13 @@ type OIDCClaims struct { func (p *ProviderData) verifyIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) { rawIDToken := getIDToken(token) - if strings.TrimSpace(rawIDToken) != "" { - if p.Verifier == nil { - return nil, ErrMissingOIDCVerifier - } - return p.Verifier.Verify(ctx, rawIDToken) + if strings.TrimSpace(rawIDToken) == "" { + return nil, ErrMissingIDToken } - return nil, ErrMissingIDToken + if p.Verifier == nil { + return nil, ErrMissingOIDCVerifier + } + return p.Verifier.Verify(ctx, rawIDToken) } // buildSessionFromClaims uses IDToken claims to populate a fresh SessionState diff --git a/providers/util.go b/providers/util.go index 055d29db..e6fdc344 100644 --- a/providers/util.go +++ b/providers/util.go @@ -73,15 +73,15 @@ func getIDToken(token *oauth2.Token) string { // formatGroup coerces an OIDC groups claim into a string // If it is non-string, marshal it into JSON. func formatGroup(rawGroup interface{}) (string, error) { - group, ok := rawGroup.(string) - if !ok { - jsonGroup, err := json.Marshal(rawGroup) - if err != nil { - return "", err - } - group = string(jsonGroup) + if group, ok := rawGroup.(string); ok { + return group, nil } - return group, nil + + jsonGroup, err := json.Marshal(rawGroup) + if err != nil { + return "", err + } + return string(jsonGroup), nil } // coerceArray extracts a field from simplejson.Json that might be a From 23b2355f852bf66aaccc47087003a1a3f4578fe7 Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Sun, 18 Oct 2020 18:14:32 -0700 Subject: [PATCH 07/28] Allow group authZ in AuthOnly endpoint via Querystring --- CHANGELOG.md | 5 ++ oauthproxy.go | 50 ++++++++++++++++++-- oauthproxy_test.go | 115 +++++++++++++++++++++++++++++++++++++-------- 3 files changed, 147 insertions(+), 23 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 84f54dec..ef5a77aa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,10 @@ - [#936](https://github.com/oauth2-proxy/oauth2-proxy/pull/936) `--user-id-claim` option is deprecated and replaced by `--oidc-email-claim` - [#630](https://github.com/oauth2-proxy/oauth2-proxy/pull/630) Gitlab projects needs a Gitlab application with the extra `read_api` enabled +- [#849](https://github.com/oauth2-proxy/oauth2-proxy/pull/849) `/oauth2/auth` `allowed_groups` querystring parameter can be paired with the `allowed-groups` configuration option. + - In this scenario, the user's group must be in both lists to not get a 401 response code. + - The `allowed_group` querystring parameter can be specified multiple times to support multiple groups. + - The `allowed_groups` querystring parameter can specify multiple comma delimited groups. - [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Existing sessions from v6.0.0 or earlier are no longer valid. They will trigger a reauthentication. - [#826](https://github.com/oauth2-proxy/oauth2-proxy/pull/826) `skip-auth-strip-headers` now applies to all requests, not just those where authentication would be skipped. - [#797](https://github.com/oauth2-proxy/oauth2-proxy/pull/797) The behavior of the Google provider Groups restriction changes with this @@ -56,6 +60,7 @@ - [#918](https://github.com/oauth2-proxy/oauth2-proxy/pull/918) Fix log header output (@JoelSpeed) - [#911](https://github.com/oauth2-proxy/oauth2-proxy/pull/911) Validate provider type on startup. - [#869](https://github.com/oauth2-proxy/oauth2-proxy/pull/869) Streamline provider interface method names and signatures (@NickMeves) +- [#849](https://github.com/oauth2-proxy/oauth2-proxy/pull/849) Support group authorization on `oauth2/auth` endpoint via `allowed_group` & `allowed_groups` querystring parameters (@NickMeves) - [#906](https://github.com/oauth2-proxy/oauth2-proxy/pull/906) Set up v6.1.x versioned documentation as default documentation (@JoelSpeed) - [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Remove v5 legacy sessions support (@NickMeves) - [#904](https://github.com/oauth2-proxy/oauth2-proxy/pull/904) Set `skip-auth-strip-headers` to `true` by default (@NickMeves) diff --git a/oauthproxy.go b/oauthproxy.go index 693d5a7a..81152082 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -744,7 +744,7 @@ func (p *OAuthProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) { case path == p.OAuthCallbackPath: p.OAuthCallback(rw, req) case path == p.AuthOnlyPath: - p.AuthenticateOnly(rw, req) + p.AuthOnly(rw, req) case path == p.UserInfoPath: p.UserInfo(rw, req) default: @@ -925,14 +925,22 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { } } -// AuthenticateOnly checks whether the user is currently logged in -func (p *OAuthProxy) AuthenticateOnly(rw http.ResponseWriter, req *http.Request) { +// AuthOnly checks whether the user is currently logged in (both authentication +// and optional authorization via `allowed_groups` querystring). +func (p *OAuthProxy) AuthOnly(rw http.ResponseWriter, req *http.Request) { session, err := p.getAuthenticatedSession(rw, req) if err != nil { http.Error(rw, "unauthorized request", http.StatusUnauthorized) return } + // Allow secondary group restrictions based on the `allowed_group` or + // `allowed_groups` querystring parameter + if !checkAllowedGroups(req, session) { + http.Error(rw, "unauthorized request", http.StatusUnauthorized) + return + } + // we are authenticated p.addHeadersForProxying(rw, req, session) p.headersChain.Then(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { @@ -1016,6 +1024,42 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R return session, nil } +func checkAllowedGroups(req *http.Request, session *sessionsapi.SessionState) bool { + allowedGroups := extractAllowedGroups(req) + if len(allowedGroups) == 0 { + return true + } + + for _, group := range session.Groups { + if _, ok := allowedGroups[group]; ok { + return true + } + } + + return false +} + +func extractAllowedGroups(req *http.Request) map[string]struct{} { + groups := map[string]struct{}{} + query := req.URL.Query() + + // multi-key singular support + if multiGroups, ok := query["allowed_group"]; ok { + for _, group := range multiGroups { + groups[group] = struct{}{} + } + } + + // single key plural comma delimited support + for _, group := range strings.Split(query.Get("allowed_groups"), ",") { + if group != "" { + groups[group] = struct{}{} + } + } + + return groups +} + // addHeadersForProxying adds the appropriate headers the request / response for proxying func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) { if session.Email == "" { diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 866109ca..3a607b94 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -1197,18 +1197,20 @@ func TestUserInfoEndpointUnauthorizedOnNoCookieSetError(t *testing.T) { assert.Equal(t, http.StatusUnauthorized, test.rw.Code) } -func NewAuthOnlyEndpointTest(modifiers ...OptionsModifier) (*ProcessCookieTest, error) { +func NewAuthOnlyEndpointTest(querystring string, modifiers ...OptionsModifier) (*ProcessCookieTest, error) { pcTest, err := NewProcessCookieTestWithOptionsModifiers(modifiers...) if err != nil { return nil, err } - pcTest.req, _ = http.NewRequest("GET", - pcTest.opts.ProxyPrefix+"/auth", nil) + pcTest.req, _ = http.NewRequest( + "GET", + fmt.Sprintf("%s/auth%s", pcTest.opts.ProxyPrefix, querystring), + nil) return pcTest, nil } func TestAuthOnlyEndpointAccepted(t *testing.T) { - test, err := NewAuthOnlyEndpointTest() + test, err := NewAuthOnlyEndpointTest("") if err != nil { t.Fatal(err) } @@ -1226,7 +1228,7 @@ func TestAuthOnlyEndpointAccepted(t *testing.T) { } func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) { - test, err := NewAuthOnlyEndpointTest() + test, err := NewAuthOnlyEndpointTest("") if err != nil { t.Fatal(err) } @@ -1238,7 +1240,7 @@ func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) { } func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { - test, err := NewAuthOnlyEndpointTest(func(opts *options.Options) { + test, err := NewAuthOnlyEndpointTest("", func(opts *options.Options) { opts.Cookie.Expire = time.Duration(24) * time.Hour }) if err != nil { @@ -1258,7 +1260,7 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { } func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) { - test, err := NewAuthOnlyEndpointTest() + test, err := NewAuthOnlyEndpointTest("") if err != nil { t.Fatal(err) } @@ -1960,7 +1962,7 @@ func TestGetJwtSession(t *testing.T) { verifier := oidc.NewVerifier("https://issuer.example.com", keyset, &oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true}) - test, err := NewAuthOnlyEndpointTest(func(opts *options.Options) { + test, err := NewAuthOnlyEndpointTest("", func(opts *options.Options) { opts.InjectRequestHeaders = []options.Header{ { Name: "Authorization", @@ -2028,7 +2030,6 @@ func TestGetJwtSession(t *testing.T) { }, }, } - opts.SkipJwtBearerTokens = true opts.SetJWTBearerVerifiers(append(opts.GetJWTBearerVerifiers(), verifier)) }) @@ -2692,32 +2693,106 @@ func TestProxyAllowedGroups(t *testing.T) { } func TestAuthOnlyAllowedGroups(t *testing.T) { - tests := []struct { + testCases := []struct { name string allowedGroups []string groups []string + querystring string expectUnauthorized bool }{ - {"NoAllowedGroups", []string{}, []string{}, false}, - {"NoAllowedGroupsUserHasGroups", []string{}, []string{"a", "b"}, false}, - {"UserInAllowedGroup", []string{"a"}, []string{"a", "b"}, false}, - {"UserNotInAllowedGroup", []string{"a"}, []string{"c"}, true}, + { + name: "NoAllowedGroups", + allowedGroups: []string{}, + groups: []string{}, + querystring: "", + expectUnauthorized: false, + }, + { + name: "NoAllowedGroupsUserHasGroups", + allowedGroups: []string{}, + groups: []string{"a", "b"}, + querystring: "", + expectUnauthorized: false, + }, + { + name: "UserInAllowedGroup", + allowedGroups: []string{"a"}, + groups: []string{"a", "b"}, + querystring: "", + expectUnauthorized: false, + }, + { + name: "UserNotInAllowedGroup", + allowedGroups: []string{"a"}, + groups: []string{"c"}, + querystring: "", + expectUnauthorized: true, + }, + { + name: "UserInQuerystringGroup", + allowedGroups: []string{"a", "b"}, + groups: []string{"a", "c"}, + querystring: "?allowed_group=a", + expectUnauthorized: false, + }, + { + name: "UserInOnlyQuerystringGroup", + allowedGroups: []string{}, + groups: []string{"a", "c"}, + querystring: "?allowed_groups=a,b", + expectUnauthorized: false, + }, + { + name: "UserInMultiParamQuerystringGroup", + allowedGroups: []string{"a", "b"}, + groups: []string{"b"}, + querystring: "?allowed_group=a&allowed_group=b", + expectUnauthorized: false, + }, + { + name: "UserInDelimitedQuerystringGroup", + allowedGroups: []string{"a", "b", "c"}, + groups: []string{"c"}, + querystring: "?allowed_groups=a,c", + expectUnauthorized: false, + }, + { + name: "UserNotInQuerystringGroup", + allowedGroups: []string{}, + groups: []string{"c"}, + querystring: "?allowed_group=a&allowed_group=b", + expectUnauthorized: true, + }, + { + name: "UserInConfigGroupNotInQuerystringGroup", + allowedGroups: []string{"a", "b", "c"}, + groups: []string{"c"}, + querystring: "?allowed_group=a&allowed_group=b", + expectUnauthorized: true, + }, + { + name: "UserInQuerystringGroupNotInConfigGroup", + allowedGroups: []string{"a", "b"}, + groups: []string{"c"}, + querystring: "?allowed_groups=b,c", + expectUnauthorized: true, + }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { emailAddress := "test" created := time.Now() session := &sessions.SessionState{ - Groups: tt.groups, + Groups: tc.groups, Email: emailAddress, AccessToken: "oauth_token", CreatedAt: &created, } - test, err := NewAuthOnlyEndpointTest(func(opts *options.Options) { - opts.AllowedGroups = tt.allowedGroups + test, err := NewAuthOnlyEndpointTest(tc.querystring, func(opts *options.Options) { + opts.AllowedGroups = tc.allowedGroups }) if err != nil { t.Fatal(err) @@ -2728,7 +2803,7 @@ func TestAuthOnlyAllowedGroups(t *testing.T) { test.proxy.ServeHTTP(test.rw, test.req) - if tt.expectUnauthorized { + if tc.expectUnauthorized { assert.Equal(t, http.StatusUnauthorized, test.rw.Code) } else { assert.Equal(t, http.StatusAccepted, test.rw.Code) From 44d83e5f95f7964c49ae81e93c2c71ad404b0312 Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Tue, 17 Nov 2020 19:03:41 -0800 Subject: [PATCH 08/28] Use StatusForbidden to prevent infinite redirects --- oauthproxy.go | 4 ++-- oauthproxy_test.go | 36 ++++++++++++++++-------------------- 2 files changed, 18 insertions(+), 22 deletions(-) diff --git a/oauthproxy.go b/oauthproxy.go index 81152082..0520593b 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -930,14 +930,14 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { func (p *OAuthProxy) AuthOnly(rw http.ResponseWriter, req *http.Request) { session, err := p.getAuthenticatedSession(rw, req) if err != nil { - http.Error(rw, "unauthorized request", http.StatusUnauthorized) + http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } // Allow secondary group restrictions based on the `allowed_group` or // `allowed_groups` querystring parameter if !checkAllowedGroups(req, session) { - http.Error(rw, "unauthorized request", http.StatusUnauthorized) + http.Error(rw, http.StatusText(http.StatusForbidden), http.StatusForbidden) return } diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 3a607b94..3cc19447 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -1236,7 +1236,7 @@ func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) { test.proxy.ServeHTTP(test.rw, test.req) assert.Equal(t, http.StatusUnauthorized, test.rw.Code) bodyBytes, _ := ioutil.ReadAll(test.rw.Body) - assert.Equal(t, "unauthorized request\n", string(bodyBytes)) + assert.Equal(t, "Unauthorized\n", string(bodyBytes)) } func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { @@ -1256,7 +1256,7 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { test.proxy.ServeHTTP(test.rw, test.req) assert.Equal(t, http.StatusUnauthorized, test.rw.Code) bodyBytes, _ := ioutil.ReadAll(test.rw.Body) - assert.Equal(t, "unauthorized request\n", string(bodyBytes)) + assert.Equal(t, "Unauthorized\n", string(bodyBytes)) } func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) { @@ -1275,7 +1275,7 @@ func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) { test.proxy.ServeHTTP(test.rw, test.req) assert.Equal(t, http.StatusUnauthorized, test.rw.Code) bodyBytes, _ := ioutil.ReadAll(test.rw.Body) - assert.Equal(t, "unauthorized request\n", string(bodyBytes)) + assert.Equal(t, "Unauthorized\n", string(bodyBytes)) } func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { @@ -2698,84 +2698,84 @@ func TestAuthOnlyAllowedGroups(t *testing.T) { allowedGroups []string groups []string querystring string - expectUnauthorized bool + expectedStatusCode int }{ { name: "NoAllowedGroups", allowedGroups: []string{}, groups: []string{}, querystring: "", - expectUnauthorized: false, + expectedStatusCode: http.StatusAccepted, }, { name: "NoAllowedGroupsUserHasGroups", allowedGroups: []string{}, groups: []string{"a", "b"}, querystring: "", - expectUnauthorized: false, + expectedStatusCode: http.StatusAccepted, }, { name: "UserInAllowedGroup", allowedGroups: []string{"a"}, groups: []string{"a", "b"}, querystring: "", - expectUnauthorized: false, + expectedStatusCode: http.StatusAccepted, }, { name: "UserNotInAllowedGroup", allowedGroups: []string{"a"}, groups: []string{"c"}, querystring: "", - expectUnauthorized: true, + expectedStatusCode: http.StatusUnauthorized, }, { name: "UserInQuerystringGroup", allowedGroups: []string{"a", "b"}, groups: []string{"a", "c"}, querystring: "?allowed_group=a", - expectUnauthorized: false, + expectedStatusCode: http.StatusAccepted, }, { name: "UserInOnlyQuerystringGroup", allowedGroups: []string{}, groups: []string{"a", "c"}, querystring: "?allowed_groups=a,b", - expectUnauthorized: false, + expectedStatusCode: http.StatusAccepted, }, { name: "UserInMultiParamQuerystringGroup", allowedGroups: []string{"a", "b"}, groups: []string{"b"}, querystring: "?allowed_group=a&allowed_group=b", - expectUnauthorized: false, + expectedStatusCode: http.StatusAccepted, }, { name: "UserInDelimitedQuerystringGroup", allowedGroups: []string{"a", "b", "c"}, groups: []string{"c"}, querystring: "?allowed_groups=a,c", - expectUnauthorized: false, + expectedStatusCode: http.StatusAccepted, }, { name: "UserNotInQuerystringGroup", allowedGroups: []string{}, groups: []string{"c"}, querystring: "?allowed_group=a&allowed_group=b", - expectUnauthorized: true, + expectedStatusCode: http.StatusForbidden, }, { name: "UserInConfigGroupNotInQuerystringGroup", allowedGroups: []string{"a", "b", "c"}, groups: []string{"c"}, querystring: "?allowed_group=a&allowed_group=b", - expectUnauthorized: true, + expectedStatusCode: http.StatusForbidden, }, { name: "UserInQuerystringGroupNotInConfigGroup", allowedGroups: []string{"a", "b"}, groups: []string{"c"}, querystring: "?allowed_groups=b,c", - expectUnauthorized: true, + expectedStatusCode: http.StatusUnauthorized, }, } @@ -2803,11 +2803,7 @@ func TestAuthOnlyAllowedGroups(t *testing.T) { test.proxy.ServeHTTP(test.rw, test.req) - if tc.expectUnauthorized { - assert.Equal(t, http.StatusUnauthorized, test.rw.Code) - } else { - assert.Equal(t, http.StatusAccepted, test.rw.Code) - } + assert.Equal(t, tc.expectedStatusCode, test.rw.Code) }) } } From 025056cba0a74f32db7148ad0f8fbd27bfb386e1 Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Thu, 19 Nov 2020 20:27:28 -0800 Subject: [PATCH 09/28] Move AuthOnly authorize logic to a dedicated method --- CHANGELOG.md | 4 ++-- oauthproxy.go | 24 ++++++++++++++++++------ 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ef5a77aa..504222d9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -54,13 +54,13 @@ - [#938](https://github.com/oauth2-proxy/oauth2-proxy/pull/938) Cleanup missed provider renaming refactor methods (@NickMeves) - [#816](https://github.com/oauth2-proxy/oauth2-proxy/pull/816) (via [#936](https://github.com/oauth2-proxy/oauth2-proxy/pull/936)) Support non-list group claims (@loafoe) - [#936](https://github.com/oauth2-proxy/oauth2-proxy/pull/936) Refactor OIDC Provider and support groups from Profile URL (@NickMeves) +- [#869](https://github.com/oauth2-proxy/oauth2-proxy/pull/869) Streamline provider interface method names and signatures (@NickMeves) +- [#849](https://github.com/oauth2-proxy/oauth2-proxy/pull/849) Support group authorization on `oauth2/auth` endpoint via `allowed_group` & `allowed_groups` querystring parameters (@NickMeves) - [#925](https://github.com/oauth2-proxy/oauth2-proxy/pull/925) Fix basic auth legacy header conversion (@JoelSpeed) - [#916](https://github.com/oauth2-proxy/oauth2-proxy/pull/916) Add AlphaOptions struct to prepare for alpha config loading (@JoelSpeed) - [#923](https://github.com/oauth2-proxy/oauth2-proxy/pull/923) Support TLS 1.3 (@aajisaka) - [#918](https://github.com/oauth2-proxy/oauth2-proxy/pull/918) Fix log header output (@JoelSpeed) - [#911](https://github.com/oauth2-proxy/oauth2-proxy/pull/911) Validate provider type on startup. -- [#869](https://github.com/oauth2-proxy/oauth2-proxy/pull/869) Streamline provider interface method names and signatures (@NickMeves) -- [#849](https://github.com/oauth2-proxy/oauth2-proxy/pull/849) Support group authorization on `oauth2/auth` endpoint via `allowed_group` & `allowed_groups` querystring parameters (@NickMeves) - [#906](https://github.com/oauth2-proxy/oauth2-proxy/pull/906) Set up v6.1.x versioned documentation as default documentation (@JoelSpeed) - [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Remove v5 legacy sessions support (@NickMeves) - [#904](https://github.com/oauth2-proxy/oauth2-proxy/pull/904) Set `skip-auth-strip-headers` to `true` by default (@NickMeves) diff --git a/oauthproxy.go b/oauthproxy.go index 0520593b..c11b83b1 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -926,7 +926,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { } // AuthOnly checks whether the user is currently logged in (both authentication -// and optional authorization via `allowed_groups` querystring). +// and optional authorization). func (p *OAuthProxy) AuthOnly(rw http.ResponseWriter, req *http.Request) { session, err := p.getAuthenticatedSession(rw, req) if err != nil { @@ -934,9 +934,9 @@ func (p *OAuthProxy) AuthOnly(rw http.ResponseWriter, req *http.Request) { return } - // Allow secondary group restrictions based on the `allowed_group` or - // `allowed_groups` querystring parameter - if !checkAllowedGroups(req, session) { + // Unauthorized cases need to return 403 to prevent infinite redirects with + // subrequest architectures + if !authOnlyAuthorize(req, session) { http.Error(rw, http.StatusText(http.StatusForbidden), http.StatusForbidden) return } @@ -1024,13 +1024,25 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R return session, nil } -func checkAllowedGroups(req *http.Request, session *sessionsapi.SessionState) bool { +// authOnlyAuthorize handles special authorization logic that is only done +// on the AuthOnly endpoint for use with Nginx subrequest architectures. +func authOnlyAuthorize(req *http.Request, s *sessionsapi.SessionState) bool { + // Allow secondary group restrictions based on the `allowed_group` or + // `allowed_groups` querystring parameter + if !checkAllowedGroups(req, s) { + return false + } + + return true +} + +func checkAllowedGroups(req *http.Request, s *sessionsapi.SessionState) bool { allowedGroups := extractAllowedGroups(req) if len(allowedGroups) == 0 { return true } - for _, group := range session.Groups { + for _, group := range s.Groups { if _, ok := allowedGroups[group]; ok { return true } From 65e15f24c1bf275919e0961b29ca1ac973a63780 Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Fri, 27 Nov 2020 09:07:21 -0800 Subject: [PATCH 10/28] Support only allowed_groups querystring --- CHANGELOG.md | 8 ++++---- oauthproxy.go | 22 ++++++++-------------- oauthproxy_test.go | 20 ++++++++++---------- 3 files changed, 22 insertions(+), 28 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 504222d9..10b6aa88 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,8 +7,7 @@ - [#936](https://github.com/oauth2-proxy/oauth2-proxy/pull/936) `--user-id-claim` option is deprecated and replaced by `--oidc-email-claim` - [#630](https://github.com/oauth2-proxy/oauth2-proxy/pull/630) Gitlab projects needs a Gitlab application with the extra `read_api` enabled - [#849](https://github.com/oauth2-proxy/oauth2-proxy/pull/849) `/oauth2/auth` `allowed_groups` querystring parameter can be paired with the `allowed-groups` configuration option. - - In this scenario, the user's group must be in both lists to not get a 401 response code. - - The `allowed_group` querystring parameter can be specified multiple times to support multiple groups. + - In this scenario, the user's group must be in both lists to not get a 401 or 403 response code. - The `allowed_groups` querystring parameter can specify multiple comma delimited groups. - [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Existing sessions from v6.0.0 or earlier are no longer valid. They will trigger a reauthentication. - [#826](https://github.com/oauth2-proxy/oauth2-proxy/pull/826) `skip-auth-strip-headers` now applies to all requests, not just those where authentication would be skipped. @@ -23,7 +22,8 @@ - [#575](https://github.com/oauth2-proxy/oauth2-proxy/pull/575) Sessions from v5.1.1 or earlier will no longer validate since they were not signed with SHA1. - Sessions from v6.0.0 or later had a graceful conversion to SHA256 that resulted in no reauthentication - Upgrading from v5.1.1 or earlier will result in a reauthentication -- [#616](https://github.com/oauth2-proxy/oauth2-proxy/pull/616) Ensure you have configured oauth2-proxy to use the `groups` scope. The user may be logged out initially as they may not currently have the `groups` claim however after going back through login process wil be authenticated. +- [#616](https://github.com/oauth2-proxy/oauth2-proxy/pull/616) Ensure you have configured oauth2-proxy to use the `groups` scope. + - The user may be logged out initially as they may not currently have the `groups` claim however after going back through login process wil be authenticated. - [#839](https://github.com/oauth2-proxy/oauth2-proxy/pull/839) Enables complex data structures for group claim entries, which are output as Json by default. ## Breaking Changes @@ -55,7 +55,7 @@ - [#816](https://github.com/oauth2-proxy/oauth2-proxy/pull/816) (via [#936](https://github.com/oauth2-proxy/oauth2-proxy/pull/936)) Support non-list group claims (@loafoe) - [#936](https://github.com/oauth2-proxy/oauth2-proxy/pull/936) Refactor OIDC Provider and support groups from Profile URL (@NickMeves) - [#869](https://github.com/oauth2-proxy/oauth2-proxy/pull/869) Streamline provider interface method names and signatures (@NickMeves) -- [#849](https://github.com/oauth2-proxy/oauth2-proxy/pull/849) Support group authorization on `oauth2/auth` endpoint via `allowed_group` & `allowed_groups` querystring parameters (@NickMeves) +- [#849](https://github.com/oauth2-proxy/oauth2-proxy/pull/849) Support group authorization on `oauth2/auth` endpoint via `allowed_groups` querystring (@NickMeves) - [#925](https://github.com/oauth2-proxy/oauth2-proxy/pull/925) Fix basic auth legacy header conversion (@JoelSpeed) - [#916](https://github.com/oauth2-proxy/oauth2-proxy/pull/916) Add AlphaOptions struct to prepare for alpha config loading (@JoelSpeed) - [#923](https://github.com/oauth2-proxy/oauth2-proxy/pull/923) Support TLS 1.3 (@aajisaka) diff --git a/oauthproxy.go b/oauthproxy.go index c11b83b1..999e1fbb 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -1027,8 +1027,8 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R // authOnlyAuthorize handles special authorization logic that is only done // on the AuthOnly endpoint for use with Nginx subrequest architectures. func authOnlyAuthorize(req *http.Request, s *sessionsapi.SessionState) bool { - // Allow secondary group restrictions based on the `allowed_group` or - // `allowed_groups` querystring parameter + // Allow secondary group restrictions based on the `allowed_groups` + // querystring parameter if !checkAllowedGroups(req, s) { return false } @@ -1053,19 +1053,13 @@ func checkAllowedGroups(req *http.Request, s *sessionsapi.SessionState) bool { func extractAllowedGroups(req *http.Request) map[string]struct{} { groups := map[string]struct{}{} + query := req.URL.Query() - - // multi-key singular support - if multiGroups, ok := query["allowed_group"]; ok { - for _, group := range multiGroups { - groups[group] = struct{}{} - } - } - - // single key plural comma delimited support - for _, group := range strings.Split(query.Get("allowed_groups"), ",") { - if group != "" { - groups[group] = struct{}{} + for _, allowedGroups := range query["allowed_groups"] { + for _, group := range strings.Split(allowedGroups, ",") { + if group != "" { + groups[group] = struct{}{} + } } } diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 3cc19447..56e046c2 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -2732,7 +2732,14 @@ func TestAuthOnlyAllowedGroups(t *testing.T) { name: "UserInQuerystringGroup", allowedGroups: []string{"a", "b"}, groups: []string{"a", "c"}, - querystring: "?allowed_group=a", + querystring: "?allowed_groups=a", + expectedStatusCode: http.StatusAccepted, + }, + { + name: "UserInMultiParamQuerystringGroup", + allowedGroups: []string{"a", "b"}, + groups: []string{"b"}, + querystring: "?allowed_groups=a&allowed_groups=b,d", expectedStatusCode: http.StatusAccepted, }, { @@ -2742,13 +2749,6 @@ func TestAuthOnlyAllowedGroups(t *testing.T) { querystring: "?allowed_groups=a,b", expectedStatusCode: http.StatusAccepted, }, - { - name: "UserInMultiParamQuerystringGroup", - allowedGroups: []string{"a", "b"}, - groups: []string{"b"}, - querystring: "?allowed_group=a&allowed_group=b", - expectedStatusCode: http.StatusAccepted, - }, { name: "UserInDelimitedQuerystringGroup", allowedGroups: []string{"a", "b", "c"}, @@ -2760,14 +2760,14 @@ func TestAuthOnlyAllowedGroups(t *testing.T) { name: "UserNotInQuerystringGroup", allowedGroups: []string{}, groups: []string{"c"}, - querystring: "?allowed_group=a&allowed_group=b", + querystring: "?allowed_groups=a,b", expectedStatusCode: http.StatusForbidden, }, { name: "UserInConfigGroupNotInQuerystringGroup", allowedGroups: []string{"a", "b", "c"}, groups: []string{"c"}, - querystring: "?allowed_group=a&allowed_group=b", + querystring: "?allowed_groups=a,b", expectedStatusCode: http.StatusForbidden, }, { From 753f6c548acd9a5c9563bf7f061d575733af25d7 Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Fri, 27 Nov 2020 10:45:55 -0800 Subject: [PATCH 11/28] Add a detailed allowed_groups example to Important Notes --- CHANGELOG.md | 7 ++++++- oauthproxy.go | 5 +++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 10b6aa88..edb631d6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,8 +7,13 @@ - [#936](https://github.com/oauth2-proxy/oauth2-proxy/pull/936) `--user-id-claim` option is deprecated and replaced by `--oidc-email-claim` - [#630](https://github.com/oauth2-proxy/oauth2-proxy/pull/630) Gitlab projects needs a Gitlab application with the extra `read_api` enabled - [#849](https://github.com/oauth2-proxy/oauth2-proxy/pull/849) `/oauth2/auth` `allowed_groups` querystring parameter can be paired with the `allowed-groups` configuration option. - - In this scenario, the user's group must be in both lists to not get a 401 or 403 response code. - The `allowed_groups` querystring parameter can specify multiple comma delimited groups. + - In this scenario, the user must have a group (from their multiple groups) present in both lists to not get a 401 or 403 response code. + - Example: + - OAuth2-Proxy globally sets the `allowed_groups` as `engineering`. + - An application using Kubernetes ingress uses the `/oauth2/auth` endpoint with `allowed_groups` querystring set to `backend`. + - A user must have a session with the groups `["engineering", "backend"]` to pass authorization. + - Another user with the groups `["engineering", "frontend"]` would fail the querystring authorization portion. - [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Existing sessions from v6.0.0 or earlier are no longer valid. They will trigger a reauthentication. - [#826](https://github.com/oauth2-proxy/oauth2-proxy/pull/826) `skip-auth-strip-headers` now applies to all requests, not just those where authentication would be skipped. - [#797](https://github.com/oauth2-proxy/oauth2-proxy/pull/797) The behavior of the Google provider Groups restriction changes with this diff --git a/oauthproxy.go b/oauthproxy.go index 999e1fbb..f97af98b 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -1026,6 +1026,11 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R // authOnlyAuthorize handles special authorization logic that is only done // on the AuthOnly endpoint for use with Nginx subrequest architectures. +// +// TODO (@NickMeves): This method is a placeholder to be extended but currently +// fails the linter. Remove the nolint when functionality expands. +// +//nolint:S1008 func authOnlyAuthorize(req *http.Request, s *sessionsapi.SessionState) bool { // Allow secondary group restrictions based on the `allowed_groups` // querystring parameter From 3369799853efcdd8a7fdb8e9cf3e8a7a2ba49f3b Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Sat, 12 Dec 2020 12:57:32 -0800 Subject: [PATCH 12/28] Migrate Keycloak to EnrichSession & support multiple groups --- pkg/apis/options/options.go | 4 +- pkg/validation/options.go | 5 +- providers/keycloak.go | 44 +++---- providers/keycloak_test.go | 232 ++++++++++++++++++++++++------------ 4 files changed, 183 insertions(+), 102 deletions(-) diff --git a/pkg/apis/options/options.go b/pkg/apis/options/options.go index 46cdcedb..c0f91422 100644 --- a/pkg/apis/options/options.go +++ b/pkg/apis/options/options.go @@ -36,7 +36,7 @@ type Options struct { TLSKeyFile string `flag:"tls-key-file" cfg:"tls_key_file"` AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"` - KeycloakGroup string `flag:"keycloak-group" cfg:"keycloak_group"` + KeycloakGroups []string `flag:"keycloak-group" cfg:"keycloak_groups"` AzureTenant string `flag:"azure-tenant" cfg:"azure_tenant"` BitbucketTeam string `flag:"bitbucket-team" cfg:"bitbucket_team"` BitbucketRepository string `flag:"bitbucket-repository" cfg:"bitbucket_repository"` @@ -181,7 +181,7 @@ func NewFlagSet() *pflag.FlagSet { flagSet.StringSlice("email-domain", []string{}, "authenticate emails with the specified domain (may be given multiple times). Use * to authenticate any email") flagSet.StringSlice("whitelist-domain", []string{}, "allowed domains for redirection after authentication. Prefix domain with a . to allow subdomains (eg .example.com)") - flagSet.String("keycloak-group", "", "restrict login to members of this group.") + flagSet.StringSlice("keycloak-group", []string{}, "restrict logins to members of these groups (may be given multiple times)") flagSet.String("azure-tenant", "common", "go to a tenant-specific or common (tenant-independent) endpoint.") flagSet.String("bitbucket-team", "", "restrict logins to members of this team") flagSet.String("bitbucket-repository", "", "restrict logins to user with access to this repository") diff --git a/pkg/validation/options.go b/pkg/validation/options.go index 4fc0b0a4..52b0fb69 100644 --- a/pkg/validation/options.go +++ b/pkg/validation/options.go @@ -263,7 +263,10 @@ func parseProviderInfo(o *options.Options, msgs []string) []string { p.SetRepo(o.GitHubRepo, o.GitHubToken) p.SetUsers(o.GitHubUsers) case *providers.KeycloakProvider: - p.SetGroup(o.KeycloakGroup) + // Backwards compatibility with `--keycloak-group` option + if len(o.KeycloakGroups) > 0 { + p.SetAllowedGroups(o.KeycloakGroups) + } case *providers.GoogleProvider: if o.GoogleServiceAccountJSON != "" { file, err := os.Open(o.GoogleServiceAccountJSON) diff --git a/providers/keycloak.go b/providers/keycloak.go index 60b3eaca..f6c74880 100644 --- a/providers/keycloak.go +++ b/providers/keycloak.go @@ -2,6 +2,7 @@ package providers import ( "context" + "fmt" "net/url" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" @@ -11,7 +12,6 @@ import ( type KeycloakProvider struct { *ProviderData - Group string } var _ Provider = (*KeycloakProvider)(nil) @@ -59,41 +59,33 @@ func NewKeycloakProvider(p *ProviderData) *KeycloakProvider { return &KeycloakProvider{ProviderData: p} } -func (p *KeycloakProvider) SetGroup(group string) { - p.Group = group -} - -func (p *KeycloakProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { +func (p *KeycloakProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error { json, err := requests.New(p.ValidateURL.String()). WithContext(ctx). SetHeader("Authorization", "Bearer "+s.AccessToken). Do(). UnmarshalJSON() if err != nil { - logger.Errorf("failed making request %s", err) - return "", err + logger.Errorf("failed making request %v", err) + return err } - if p.Group != "" { - var groups, err = json.Get("groups").Array() - if err != nil { - logger.Printf("groups not found %s", err) - return "", err - } - - var found = false - for i := range groups { - if groups[i].(string) == p.Group { - found = true - break + groups, err := json.Get("groups").StringArray() + if err != nil { + logger.Errorf("Warning: unable to extract groups from userinfo endpoint: %v", err) + } else { + for _, group := range groups { + if group != "" { + s.Groups = append(s.Groups, group) } } - - if !found { - logger.Printf("group not found, access denied") - return "", nil - } } - return json.Get("email").String() + email, err := json.Get("email").String() + if err != nil { + return fmt.Errorf("unable to extract email from userinfo endpoint: %v", err) + } + s.Email = email + + return nil } diff --git a/providers/keycloak_test.go b/providers/keycloak_test.go index 3f419f2e..d7ae4391 100644 --- a/providers/keycloak_test.go +++ b/providers/keycloak_test.go @@ -2,17 +2,33 @@ package providers import ( "context" + "errors" + "fmt" "net/http" "net/http/httptest" "net/url" "testing" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" . "github.com/onsi/gomega" "github.com/stretchr/testify/assert" ) -func testKeycloakProvider(hostname, group string) *KeycloakProvider { +const ( + keycloakAccessToken = "eyJKeycloak.eyJAccess.Token" + keycloakUserinfoPath = "/api/v3/user" + + // Userinfo Test Cases + tcUIStandard = "userinfo-standard" + tcUIFail = "userinfo-fail" + tcUISingleGroup = "userinfo-single-group" + tcUIMissingEmail = "userinfo-missing-email" + tcUIMissingGroups = "userinfo-missing-groups" +) + +func testKeycloakProvider(backend *httptest.Server) (*KeycloakProvider, error) { p := NewKeycloakProvider( &ProviderData{ ProviderName: "", @@ -22,38 +38,165 @@ func testKeycloakProvider(hostname, group string) *KeycloakProvider { ValidateURL: &url.URL{}, Scope: ""}) - if group != "" { - p.SetGroup(group) - } + if backend != nil { + bURL, err := url.Parse(backend.URL) + if err != nil { + return nil, err + } + hostname := bURL.Host - if hostname != "" { updateURL(p.Data().LoginURL, hostname) updateURL(p.Data().RedeemURL, hostname) updateURL(p.Data().ProfileURL, hostname) updateURL(p.Data().ValidateURL, hostname) } - return p + + return p, nil } -func testKeycloakBackend(payload string) *httptest.Server { - path := "/api/v3/user" - +func testKeycloakBackend() *httptest.Server { return httptest.NewServer(http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - url := r.URL - if url.Path != path { + if r.URL.Path != keycloakUserinfoPath { w.WriteHeader(404) - } else if !IsAuthorizedInHeader(r.Header) { - w.WriteHeader(403) - } else { + } + + var err error + switch r.URL.Query().Get("testcase") { + case tcUIStandard: w.WriteHeader(200) - w.Write([]byte(payload)) + _, err = w.Write([]byte(` + { + "email": "michael.bland@gsa.gov", + "groups": [ + "test-grp1", + "test-grp2" + ] + } + `)) + case tcUIFail: + w.WriteHeader(500) + case tcUISingleGroup: + w.WriteHeader(200) + _, err = w.Write([]byte(` + { + "email": "michael.bland@gsa.gov", + "groups": ["test-grp1"] + } + `)) + case tcUIMissingEmail: + w.WriteHeader(200) + _, err = w.Write([]byte(` + { + "groups": [ + "test-grp1", + "test-grp2" + ] + } + `)) + case tcUIMissingGroups: + w.WriteHeader(200) + _, err = w.Write([]byte(` + { + "email": "michael.bland@gsa.gov" + } + `)) + default: + w.WriteHeader(404) + } + if err != nil { + panic(err) } })) } +var _ = Describe("Keycloak Provider Tests", func() { + var p *KeycloakProvider + var b *httptest.Server + + BeforeEach(func() { + b = testKeycloakBackend() + + var err error + p, err = testKeycloakProvider(b) + Expect(err).To(BeNil()) + }) + + AfterEach(func() { + b.Close() + }) + + Context("EnrichSession", func() { + type enrichSessionTableInput struct { + testcase string + expectedError error + expectedEmail string + expectedGroups []string + } + + DescribeTable("should return expected results", + func(in enrichSessionTableInput) { + var err error + p.ValidateURL, err = url.Parse( + fmt.Sprintf("%s%s?testcase=%s", b.URL, keycloakUserinfoPath, in.testcase), + ) + Expect(err).To(BeNil()) + + session := &sessions.SessionState{AccessToken: keycloakAccessToken} + err = p.EnrichSession(context.Background(), session) + + if in.expectedError != nil { + Expect(err).To(Equal(in.expectedError)) + } else { + Expect(err).To(BeNil()) + } + + Expect(session.Email).To(Equal(in.expectedEmail)) + + if in.expectedGroups != nil { + Expect(session.Groups).To(Equal(in.expectedGroups)) + } else { + Expect(session.Groups).To(BeNil()) + } + }, + Entry("email and multiple groups", enrichSessionTableInput{ + testcase: tcUIStandard, + expectedError: nil, + expectedEmail: "michael.bland@gsa.gov", + expectedGroups: []string{"test-grp1", "test-grp2"}, + }), + Entry("email and single group", enrichSessionTableInput{ + testcase: tcUISingleGroup, + expectedError: nil, + expectedEmail: "michael.bland@gsa.gov", + expectedGroups: []string{"test-grp1"}, + }), + Entry("email and no groups", enrichSessionTableInput{ + testcase: tcUIMissingGroups, + expectedError: nil, + expectedEmail: "michael.bland@gsa.gov", + expectedGroups: nil, + }), + Entry("missing email", enrichSessionTableInput{ + testcase: tcUIMissingEmail, + expectedError: errors.New( + "unable to extract email from userinfo endpoint: type assertion to string failed"), + expectedEmail: "", + expectedGroups: []string{"test-grp1", "test-grp2"}, + }), + Entry("request failure", enrichSessionTableInput{ + testcase: tcUIFail, + expectedError: errors.New(`unexpected status "500": `), + expectedEmail: "", + expectedGroups: nil, + }), + ) + }) +}) + func TestKeycloakProviderDefaults(t *testing.T) { - p := testKeycloakProvider("", "") + p, err := testKeycloakProvider(nil) + assert.NoError(t, err) assert.NotEqual(t, nil, p) assert.Equal(t, "Keycloak", p.Data().ProviderName) assert.Equal(t, "https://keycloak.org/oauth/authorize", @@ -104,60 +247,3 @@ func TestKeycloakProviderOverrides(t *testing.T) { p.Data().ValidateURL.String()) assert.Equal(t, "profile", p.Data().Scope) } - -func TestKeycloakProviderGetEmailAddress(t *testing.T) { - b := testKeycloakBackend("{\"email\": \"michael.bland@gsa.gov\"}") - defer b.Close() - - bURL, _ := url.Parse(b.URL) - p := testKeycloakProvider(bURL.Host, "") - - session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(context.Background(), session) - assert.Equal(t, nil, err) - assert.Equal(t, "michael.bland@gsa.gov", email) -} - -func TestKeycloakProviderGetEmailAddressAndGroup(t *testing.T) { - b := testKeycloakBackend("{\"email\": \"michael.bland@gsa.gov\", \"groups\": [\"test-grp1\", \"test-grp2\"]}") - defer b.Close() - - bURL, _ := url.Parse(b.URL) - p := testKeycloakProvider(bURL.Host, "test-grp1") - - session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(context.Background(), session) - assert.Equal(t, nil, err) - assert.Equal(t, "michael.bland@gsa.gov", email) -} - -// Note that trying to trigger the "failed building request" case is not -// practical, since the only way it can fail is if the URL fails to parse. -func TestKeycloakProviderGetEmailAddressFailedRequest(t *testing.T) { - b := testKeycloakBackend("unused payload") - defer b.Close() - - bURL, _ := url.Parse(b.URL) - p := testKeycloakProvider(bURL.Host, "") - - // We'll trigger a request failure by using an unexpected access - // token. Alternatively, we could allow the parsing of the payload as - // JSON to fail. - session := &sessions.SessionState{AccessToken: "unexpected_access_token"} - email, err := p.GetEmailAddress(context.Background(), session) - assert.NotEqual(t, nil, err) - assert.Equal(t, "", email) -} - -func TestKeycloakProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { - b := testKeycloakBackend("{\"foo\": \"bar\"}") - defer b.Close() - - bURL, _ := url.Parse(b.URL) - p := testKeycloakProvider(bURL.Host, "") - - session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(context.Background(), session) - assert.NotEqual(t, nil, err) - assert.Equal(t, "", email) -} From 0886f8035cdf877328cbf2bc0f05d50848922ab5 Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Sat, 12 Dec 2020 13:14:57 -0800 Subject: [PATCH 13/28] Move all Keycloak unit tests to Ginkgo --- providers/keycloak_test.go | 236 +++++++++++++++++-------------------- 1 file changed, 110 insertions(+), 126 deletions(-) diff --git a/providers/keycloak_test.go b/providers/keycloak_test.go index d7ae4391..b3920810 100644 --- a/providers/keycloak_test.go +++ b/providers/keycloak_test.go @@ -7,20 +7,18 @@ import ( "net/http" "net/http/httptest" "net/url" - "testing" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" . "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo/extensions/table" . "github.com/onsi/gomega" - "github.com/stretchr/testify/assert" ) const ( keycloakAccessToken = "eyJKeycloak.eyJAccess.Token" keycloakUserinfoPath = "/api/v3/user" - // Userinfo Test Cases + // Userinfo Test Cases querystring toggles tcUIStandard = "userinfo-standard" tcUIFail = "userinfo-fail" tcUISingleGroup = "userinfo-single-group" @@ -111,139 +109,125 @@ func testKeycloakBackend() *httptest.Server { } var _ = Describe("Keycloak Provider Tests", func() { - var p *KeycloakProvider - var b *httptest.Server + Context("New Provider Init", func() { + It("uses defaults", func() { + providerData := NewKeycloakProvider(&ProviderData{}).Data() + Expect(providerData.ProviderName).To(Equal("Keycloak")) + Expect(providerData.LoginURL.String()).To(Equal("https://keycloak.org/oauth/authorize")) + Expect(providerData.RedeemURL.String()).To(Equal("https://keycloak.org/oauth/token")) + Expect(providerData.ProfileURL.String()).To(Equal("")) + Expect(providerData.ValidateURL.String()).To(Equal("https://keycloak.org/api/v3/user")) + Expect(providerData.Scope).To(Equal("api")) + }) - BeforeEach(func() { - b = testKeycloakBackend() + It("overrides defaults", func() { + p := NewKeycloakProvider( + &ProviderData{ + LoginURL: &url.URL{ + Scheme: "https", + Host: "example.com", + Path: "/oauth/auth"}, + RedeemURL: &url.URL{ + Scheme: "https", + Host: "example.com", + Path: "/oauth/token"}, + ValidateURL: &url.URL{ + Scheme: "https", + Host: "example.com", + Path: "/api/v3/user"}, + Scope: "profile"}) + providerData := p.Data() - var err error - p, err = testKeycloakProvider(b) - Expect(err).To(BeNil()) + Expect(providerData.ProviderName).To(Equal("Keycloak")) + Expect(providerData.LoginURL.String()).To(Equal("https://example.com/oauth/auth")) + Expect(providerData.RedeemURL.String()).To(Equal("https://example.com/oauth/token")) + Expect(providerData.ProfileURL.String()).To(Equal("")) + Expect(providerData.ValidateURL.String()).To(Equal("https://example.com/api/v3/user")) + Expect(providerData.Scope).To(Equal("profile")) + }) }) - AfterEach(func() { - b.Close() - }) + Context("With a test HTTP Server & Provider", func() { + var p *KeycloakProvider + var b *httptest.Server - Context("EnrichSession", func() { - type enrichSessionTableInput struct { - testcase string - expectedError error - expectedEmail string - expectedGroups []string - } + BeforeEach(func() { + b = testKeycloakBackend() - DescribeTable("should return expected results", - func(in enrichSessionTableInput) { - var err error - p.ValidateURL, err = url.Parse( - fmt.Sprintf("%s%s?testcase=%s", b.URL, keycloakUserinfoPath, in.testcase), - ) - Expect(err).To(BeNil()) + var err error + p, err = testKeycloakProvider(b) + Expect(err).To(BeNil()) + }) - session := &sessions.SessionState{AccessToken: keycloakAccessToken} - err = p.EnrichSession(context.Background(), session) + AfterEach(func() { + b.Close() + }) - if in.expectedError != nil { - Expect(err).To(Equal(in.expectedError)) - } else { + Context("EnrichSession", func() { + type enrichSessionTableInput struct { + testcase string + expectedError error + expectedEmail string + expectedGroups []string + } + + DescribeTable("should return expected results", + func(in enrichSessionTableInput) { + var err error + p.ValidateURL, err = url.Parse( + fmt.Sprintf("%s%s?testcase=%s", b.URL, keycloakUserinfoPath, in.testcase), + ) Expect(err).To(BeNil()) - } - Expect(session.Email).To(Equal(in.expectedEmail)) + session := &sessions.SessionState{AccessToken: keycloakAccessToken} + err = p.EnrichSession(context.Background(), session) - if in.expectedGroups != nil { - Expect(session.Groups).To(Equal(in.expectedGroups)) - } else { - Expect(session.Groups).To(BeNil()) - } - }, - Entry("email and multiple groups", enrichSessionTableInput{ - testcase: tcUIStandard, - expectedError: nil, - expectedEmail: "michael.bland@gsa.gov", - expectedGroups: []string{"test-grp1", "test-grp2"}, - }), - Entry("email and single group", enrichSessionTableInput{ - testcase: tcUISingleGroup, - expectedError: nil, - expectedEmail: "michael.bland@gsa.gov", - expectedGroups: []string{"test-grp1"}, - }), - Entry("email and no groups", enrichSessionTableInput{ - testcase: tcUIMissingGroups, - expectedError: nil, - expectedEmail: "michael.bland@gsa.gov", - expectedGroups: nil, - }), - Entry("missing email", enrichSessionTableInput{ - testcase: tcUIMissingEmail, - expectedError: errors.New( - "unable to extract email from userinfo endpoint: type assertion to string failed"), - expectedEmail: "", - expectedGroups: []string{"test-grp1", "test-grp2"}, - }), - Entry("request failure", enrichSessionTableInput{ - testcase: tcUIFail, - expectedError: errors.New(`unexpected status "500": `), - expectedEmail: "", - expectedGroups: nil, - }), - ) + if in.expectedError != nil { + Expect(err).To(Equal(in.expectedError)) + } else { + Expect(err).To(BeNil()) + } + + Expect(session.Email).To(Equal(in.expectedEmail)) + + if in.expectedGroups != nil { + Expect(session.Groups).To(Equal(in.expectedGroups)) + } else { + Expect(session.Groups).To(BeNil()) + } + }, + Entry("email and multiple groups", enrichSessionTableInput{ + testcase: tcUIStandard, + expectedError: nil, + expectedEmail: "michael.bland@gsa.gov", + expectedGroups: []string{"test-grp1", "test-grp2"}, + }), + Entry("email and single group", enrichSessionTableInput{ + testcase: tcUISingleGroup, + expectedError: nil, + expectedEmail: "michael.bland@gsa.gov", + expectedGroups: []string{"test-grp1"}, + }), + Entry("email and no groups", enrichSessionTableInput{ + testcase: tcUIMissingGroups, + expectedError: nil, + expectedEmail: "michael.bland@gsa.gov", + expectedGroups: nil, + }), + Entry("missing email", enrichSessionTableInput{ + testcase: tcUIMissingEmail, + expectedError: errors.New( + "unable to extract email from userinfo endpoint: type assertion to string failed"), + expectedEmail: "", + expectedGroups: []string{"test-grp1", "test-grp2"}, + }), + Entry("request failure", enrichSessionTableInput{ + testcase: tcUIFail, + expectedError: errors.New(`unexpected status "500": `), + expectedEmail: "", + expectedGroups: nil, + }), + ) + }) }) }) - -func TestKeycloakProviderDefaults(t *testing.T) { - p, err := testKeycloakProvider(nil) - assert.NoError(t, err) - assert.NotEqual(t, nil, p) - assert.Equal(t, "Keycloak", p.Data().ProviderName) - assert.Equal(t, "https://keycloak.org/oauth/authorize", - p.Data().LoginURL.String()) - assert.Equal(t, "https://keycloak.org/oauth/token", - p.Data().RedeemURL.String()) - assert.Equal(t, "https://keycloak.org/api/v3/user", - p.Data().ValidateURL.String()) - assert.Equal(t, "api", p.Data().Scope) -} - -func TestNewKeycloakProvider(t *testing.T) { - g := NewWithT(t) - - // Test that defaults are set when calling for a new provider with nothing set - providerData := NewKeycloakProvider(&ProviderData{}).Data() - g.Expect(providerData.ProviderName).To(Equal("Keycloak")) - g.Expect(providerData.LoginURL.String()).To(Equal("https://keycloak.org/oauth/authorize")) - g.Expect(providerData.RedeemURL.String()).To(Equal("https://keycloak.org/oauth/token")) - g.Expect(providerData.ProfileURL.String()).To(Equal("")) - g.Expect(providerData.ValidateURL.String()).To(Equal("https://keycloak.org/api/v3/user")) - g.Expect(providerData.Scope).To(Equal("api")) -} - -func TestKeycloakProviderOverrides(t *testing.T) { - p := NewKeycloakProvider( - &ProviderData{ - LoginURL: &url.URL{ - Scheme: "https", - Host: "example.com", - Path: "/oauth/auth"}, - RedeemURL: &url.URL{ - Scheme: "https", - Host: "example.com", - Path: "/oauth/token"}, - ValidateURL: &url.URL{ - Scheme: "https", - Host: "example.com", - Path: "/api/v3/user"}, - Scope: "profile"}) - assert.NotEqual(t, nil, p) - assert.Equal(t, "Keycloak", p.Data().ProviderName) - assert.Equal(t, "https://example.com/oauth/auth", - p.Data().LoginURL.String()) - assert.Equal(t, "https://example.com/oauth/token", - p.Data().RedeemURL.String()) - assert.Equal(t, "https://example.com/api/v3/user", - p.Data().ValidateURL.String()) - assert.Equal(t, "profile", p.Data().Scope) -} From 138a6b128a95f658da7920f523f840f7bd09c0c6 Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Sat, 12 Dec 2020 13:22:15 -0800 Subject: [PATCH 14/28] Use ProfileURL for userinfo EnrichSession calls in Keycloak --- providers/keycloak.go | 11 ++++++++++- providers/keycloak_test.go | 8 ++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/providers/keycloak.go b/providers/keycloak.go index f6c74880..66eda948 100644 --- a/providers/keycloak.go +++ b/providers/keycloak.go @@ -47,6 +47,7 @@ var ( } ) +// NewKeycloakProvider creates a KeyCloakProvider using the passed ProviderData func NewKeycloakProvider(p *ProviderData) *KeycloakProvider { p.setProviderDefaults(providerDefaults{ name: keycloakProviderName, @@ -59,8 +60,16 @@ func NewKeycloakProvider(p *ProviderData) *KeycloakProvider { return &KeycloakProvider{ProviderData: p} } +// EnrichSession uses the Keycloak userinfo endpoint to populate the session's +// email and groups. func (p *KeycloakProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error { - json, err := requests.New(p.ValidateURL.String()). + // Fallback to ValidateURL if ProfileURL not set for legacy compatibility + userinfoURL := p.ValidateURL.String() + if p.ProfileURL != nil { + userinfoURL = p.ProfileURL.String() + } + + json, err := requests.New(userinfoURL). WithContext(ctx). SetHeader("Authorization", "Bearer "+s.AccessToken). Do(). diff --git a/providers/keycloak_test.go b/providers/keycloak_test.go index b3920810..7c0b457b 100644 --- a/providers/keycloak_test.go +++ b/providers/keycloak_test.go @@ -131,6 +131,10 @@ var _ = Describe("Keycloak Provider Tests", func() { Scheme: "https", Host: "example.com", Path: "/oauth/token"}, + ProfileURL: &url.URL{ + Scheme: "https", + Host: "example.com", + Path: "/api/v3/user"}, ValidateURL: &url.URL{ Scheme: "https", Host: "example.com", @@ -141,7 +145,7 @@ var _ = Describe("Keycloak Provider Tests", func() { Expect(providerData.ProviderName).To(Equal("Keycloak")) Expect(providerData.LoginURL.String()).To(Equal("https://example.com/oauth/auth")) Expect(providerData.RedeemURL.String()).To(Equal("https://example.com/oauth/token")) - Expect(providerData.ProfileURL.String()).To(Equal("")) + Expect(providerData.ProfileURL.String()).To(Equal("https://example.com/api/v3/user")) Expect(providerData.ValidateURL.String()).To(Equal("https://example.com/api/v3/user")) Expect(providerData.Scope).To(Equal("profile")) }) @@ -174,7 +178,7 @@ var _ = Describe("Keycloak Provider Tests", func() { DescribeTable("should return expected results", func(in enrichSessionTableInput) { var err error - p.ValidateURL, err = url.Parse( + p.ProfileURL, err = url.Parse( fmt.Sprintf("%s%s?testcase=%s", b.URL, keycloakUserinfoPath, in.testcase), ) Expect(err).To(BeNil()) From f07a5630f1138ce8364edda3de585a907c44d210 Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Sat, 12 Dec 2020 13:50:34 -0800 Subject: [PATCH 15/28] Update Keycloak documentation --- CHANGELOG.md | 5 +++++ docs/docs/configuration/auth.md | 26 ++++++++++++++++++-------- providers/keycloak.go | 4 +--- 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index edb631d6..6af45d6a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ ## Important Notes +- [#953](https://github.com/oauth2-proxy/oauth2-proxy/pull/953) Keycloak will now use `--profile-url` if set for the userinfo endpoint + instead of `--validate-url`. `--validate-url` will still work for backwards compatibility. - [#936](https://github.com/oauth2-proxy/oauth2-proxy/pull/936) `--user-id-claim` option is deprecated and replaced by `--oidc-email-claim` - [#630](https://github.com/oauth2-proxy/oauth2-proxy/pull/630) Gitlab projects needs a Gitlab application with the extra `read_api` enabled - [#849](https://github.com/oauth2-proxy/oauth2-proxy/pull/849) `/oauth2/auth` `allowed_groups` querystring parameter can be paired with the `allowed-groups` configuration option. @@ -33,6 +35,8 @@ ## Breaking Changes +- [#953](https://github.com/oauth2-proxy/oauth2-proxy/pull/953) In config files & envvar configs, `keycloak_group` is now the plural `keycloak_groups`. + Flag configs are still `--keycloak-group` but it can be passed multiple times. - [#911](https://github.com/oauth2-proxy/oauth2-proxy/pull/911) Specifying a non-existent provider will cause OAuth2-Proxy to fail on startup instead of defaulting to "google". - [#797](https://github.com/oauth2-proxy/oauth2-proxy/pull/797) Security changes to Google provider group authorization flow - If you change the list of allowed groups, existing sessions that now don't have a valid group will be logged out immediately. @@ -54,6 +58,7 @@ ## Changes since v6.1.1 +- [#953](https://github.com/oauth2-proxy/oauth2-proxy/pull/953) Migrate Keycloak to EnrichSession & support multiple groups for authorization (@NickMeves) - [#630](https://github.com/oauth2-proxy/oauth2-proxy/pull/630) Add support for Gitlab project based authentication (@factorysh) - [#907](https://github.com/oauth2-proxy/oauth2-proxy/pull/907) Introduce alpha configuration option to enable testing of structured configuration (@JoelSpeed) - [#938](https://github.com/oauth2-proxy/oauth2-proxy/pull/938) Cleanup missed provider renaming refactor methods (@NickMeves) diff --git a/docs/docs/configuration/auth.md b/docs/docs/configuration/auth.md index 7c5bac39..f16f2e26 100644 --- a/docs/docs/configuration/auth.md +++ b/docs/docs/configuration/auth.md @@ -135,15 +135,25 @@ If you are using GitHub enterprise, make sure you set the following to the appro Make sure you set the following to the appropriate url: - -provider=keycloak - -client-id= - -client-secret= - -login-url="http(s):///auth/realms//protocol/openid-connect/auth" - -redeem-url="http(s):///auth/realms//protocol/openid-connect/token" - -validate-url="http(s):///auth/realms//protocol/openid-connect/userinfo" - -keycloak-group= + --provider=keycloak + --client-id= + --client-secret= + --login-url="http(s):///auth/realms//protocol/openid-connect/auth" + --redeem-url="http(s):///auth/realms//protocol/openid-connect/token" + --profile-url="http(s):///auth/realms//protocol/openid-connect/userinfo" + --validate-url="http(s):///auth/realms//protocol/openid-connect/userinfo" + --keycloak-group= + --keycloak-group= + +For group based authorization, the optional `--keycloak-group` (legacy) or `--allowed-group` (global standard) +flags can be used to specify which groups to limit access to. -The group management in keycloak is using a tree. If you create a group named admin in keycloak you should define the 'keycloak-group' value to /admin. +If these are unset but a `groups` mapper is set up above in step (3), the provider will still +populate the `X-Forwarded-Groups` header to your upstream server with the `groups` data in the +Keycloak userinfo endpoint response. + +The group management in keycloak is using a tree. If you create a group named admin in keycloak +you should define the 'keycloak-group' value to /admin. ### GitLab Auth Provider diff --git a/providers/keycloak.go b/providers/keycloak.go index 66eda948..03d3194c 100644 --- a/providers/keycloak.go +++ b/providers/keycloak.go @@ -80,9 +80,7 @@ func (p *KeycloakProvider) EnrichSession(ctx context.Context, s *sessions.Sessio } groups, err := json.Get("groups").StringArray() - if err != nil { - logger.Errorf("Warning: unable to extract groups from userinfo endpoint: %v", err) - } else { + if err == nil { for _, group := range groups { if group != "" { s.Groups = append(s.Groups, group) From 816d9a4566de86608c11ab70ae1161b71a325361 Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Mon, 21 Dec 2020 13:46:54 -0800 Subject: [PATCH 16/28] Use a generic http.HandlerFunc in Keycloak tests --- providers/keycloak_test.go | 245 +++++++++++++++++-------------------- 1 file changed, 109 insertions(+), 136 deletions(-) diff --git a/providers/keycloak_test.go b/providers/keycloak_test.go index 7c0b457b..513be643 100644 --- a/providers/keycloak_test.go +++ b/providers/keycloak_test.go @@ -17,13 +17,6 @@ import ( const ( keycloakAccessToken = "eyJKeycloak.eyJAccess.Token" keycloakUserinfoPath = "/api/v3/user" - - // Userinfo Test Cases querystring toggles - tcUIStandard = "userinfo-standard" - tcUIFail = "userinfo-fail" - tcUISingleGroup = "userinfo-single-group" - tcUIMissingEmail = "userinfo-missing-email" - tcUIMissingGroups = "userinfo-missing-groups" ) func testKeycloakProvider(backend *httptest.Server) (*KeycloakProvider, error) { @@ -52,62 +45,6 @@ func testKeycloakProvider(backend *httptest.Server) (*KeycloakProvider, error) { return p, nil } -func testKeycloakBackend() *httptest.Server { - return httptest.NewServer(http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != keycloakUserinfoPath { - w.WriteHeader(404) - } - - var err error - switch r.URL.Query().Get("testcase") { - case tcUIStandard: - w.WriteHeader(200) - _, err = w.Write([]byte(` - { - "email": "michael.bland@gsa.gov", - "groups": [ - "test-grp1", - "test-grp2" - ] - } - `)) - case tcUIFail: - w.WriteHeader(500) - case tcUISingleGroup: - w.WriteHeader(200) - _, err = w.Write([]byte(` - { - "email": "michael.bland@gsa.gov", - "groups": ["test-grp1"] - } - `)) - case tcUIMissingEmail: - w.WriteHeader(200) - _, err = w.Write([]byte(` - { - "groups": [ - "test-grp1", - "test-grp2" - ] - } - `)) - case tcUIMissingGroups: - w.WriteHeader(200) - _, err = w.Write([]byte(` - { - "email": "michael.bland@gsa.gov" - } - `)) - default: - w.WriteHeader(404) - } - if err != nil { - panic(err) - } - })) -} - var _ = Describe("Keycloak Provider Tests", func() { Context("New Provider Init", func() { It("uses defaults", func() { @@ -151,87 +88,123 @@ var _ = Describe("Keycloak Provider Tests", func() { }) }) - Context("With a test HTTP Server & Provider", func() { - var p *KeycloakProvider - var b *httptest.Server + Context("EnrichSession", func() { + type enrichSessionTableInput struct { + backendHandler http.HandlerFunc + expectedError error + expectedEmail string + expectedGroups []string + } - BeforeEach(func() { - b = testKeycloakBackend() + DescribeTable("should return expected results", + func(in enrichSessionTableInput) { + backend := httptest.NewServer(in.backendHandler) + p, err := testKeycloakProvider(backend) + Expect(err).To(BeNil()) - var err error - p, err = testKeycloakProvider(b) - Expect(err).To(BeNil()) - }) + p.ProfileURL, err = url.Parse( + fmt.Sprintf("%s%s", backend.URL, keycloakUserinfoPath), + ) + Expect(err).To(BeNil()) - AfterEach(func() { - b.Close() - }) + session := &sessions.SessionState{AccessToken: keycloakAccessToken} + err = p.EnrichSession(context.Background(), session) - Context("EnrichSession", func() { - type enrichSessionTableInput struct { - testcase string - expectedError error - expectedEmail string - expectedGroups []string - } - - DescribeTable("should return expected results", - func(in enrichSessionTableInput) { - var err error - p.ProfileURL, err = url.Parse( - fmt.Sprintf("%s%s?testcase=%s", b.URL, keycloakUserinfoPath, in.testcase), - ) + if in.expectedError != nil { + Expect(err).To(Equal(in.expectedError)) + } else { Expect(err).To(BeNil()) + } - session := &sessions.SessionState{AccessToken: keycloakAccessToken} - err = p.EnrichSession(context.Background(), session) + Expect(session.Email).To(Equal(in.expectedEmail)) - if in.expectedError != nil { - Expect(err).To(Equal(in.expectedError)) - } else { - Expect(err).To(BeNil()) - } - - Expect(session.Email).To(Equal(in.expectedEmail)) - - if in.expectedGroups != nil { - Expect(session.Groups).To(Equal(in.expectedGroups)) - } else { - Expect(session.Groups).To(BeNil()) + if in.expectedGroups != nil { + Expect(session.Groups).To(Equal(in.expectedGroups)) + } else { + Expect(session.Groups).To(BeNil()) + } + }, + Entry("email and multiple groups", enrichSessionTableInput{ + backendHandler: func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(200) + _, err := w.Write([]byte(` + { + "email": "michael.bland@gsa.gov", + "groups": [ + "test-grp1", + "test-grp2" + ] + } + `)) + if err != nil { + panic(err) } }, - Entry("email and multiple groups", enrichSessionTableInput{ - testcase: tcUIStandard, - expectedError: nil, - expectedEmail: "michael.bland@gsa.gov", - expectedGroups: []string{"test-grp1", "test-grp2"}, - }), - Entry("email and single group", enrichSessionTableInput{ - testcase: tcUISingleGroup, - expectedError: nil, - expectedEmail: "michael.bland@gsa.gov", - expectedGroups: []string{"test-grp1"}, - }), - Entry("email and no groups", enrichSessionTableInput{ - testcase: tcUIMissingGroups, - expectedError: nil, - expectedEmail: "michael.bland@gsa.gov", - expectedGroups: nil, - }), - Entry("missing email", enrichSessionTableInput{ - testcase: tcUIMissingEmail, - expectedError: errors.New( - "unable to extract email from userinfo endpoint: type assertion to string failed"), - expectedEmail: "", - expectedGroups: []string{"test-grp1", "test-grp2"}, - }), - Entry("request failure", enrichSessionTableInput{ - testcase: tcUIFail, - expectedError: errors.New(`unexpected status "500": `), - expectedEmail: "", - expectedGroups: nil, - }), - ) - }) + expectedError: nil, + expectedEmail: "michael.bland@gsa.gov", + expectedGroups: []string{"test-grp1", "test-grp2"}, + }), + Entry("email and single group", enrichSessionTableInput{ + backendHandler: func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(200) + _, err := w.Write([]byte(` + { + "email": "michael.bland@gsa.gov", + "groups": ["test-grp1"] + } + `)) + if err != nil { + panic(err) + } + }, + expectedError: nil, + expectedEmail: "michael.bland@gsa.gov", + expectedGroups: []string{"test-grp1"}, + }), + Entry("email and no groups", enrichSessionTableInput{ + backendHandler: func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(200) + _, err := w.Write([]byte(` + { + "email": "michael.bland@gsa.gov" + } + `)) + if err != nil { + panic(err) + } + }, + expectedError: nil, + expectedEmail: "michael.bland@gsa.gov", + expectedGroups: nil, + }), + Entry("missing email", enrichSessionTableInput{ + backendHandler: func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(200) + _, err := w.Write([]byte(` + { + "groups": [ + "test-grp1", + "test-grp2" + ] + } + `)) + if err != nil { + panic(err) + } + }, + expectedError: errors.New( + "unable to extract email from userinfo endpoint: type assertion to string failed"), + expectedEmail: "", + expectedGroups: []string{"test-grp1", "test-grp2"}, + }), + Entry("request failure", enrichSessionTableInput{ + backendHandler: func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(500) + }, + expectedError: errors.New(`unexpected status "500": `), + expectedEmail: "", + expectedGroups: nil, + }), + ) }) }) From 4b28e6886cdba86ac03f5b74c88947c238c80600 Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Tue, 22 Dec 2020 21:34:15 -0800 Subject: [PATCH 17/28] Handle ValidateURL fallback for nil & empty struct cases --- providers/keycloak.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/providers/keycloak.go b/providers/keycloak.go index 03d3194c..a1e4f064 100644 --- a/providers/keycloak.go +++ b/providers/keycloak.go @@ -64,12 +64,12 @@ func NewKeycloakProvider(p *ProviderData) *KeycloakProvider { // email and groups. func (p *KeycloakProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error { // Fallback to ValidateURL if ProfileURL not set for legacy compatibility - userinfoURL := p.ValidateURL.String() - if p.ProfileURL != nil { - userinfoURL = p.ProfileURL.String() + profileURL := p.ValidateURL.String() + if p.ProfileURL.String() != "" { + profileURL = p.ProfileURL.String() } - json, err := requests.New(userinfoURL). + json, err := requests.New(profileURL). WithContext(ctx). SetHeader("Authorization", "Bearer "+s.AccessToken). Do(). From 1d74a51cd70875362b09d77f3d5f9824d3d4d564 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C4=B0lteri=C5=9F=20Ero=C4=9Flu?= Date: Sat, 2 Jan 2021 02:23:11 +0300 Subject: [PATCH 18/28] Use X-Forwarded-{Proto,Host,Uri} on redirect as last resort (#957) --- CHANGELOG.md | 2 + docs/docs/configuration/overview.md | 69 +++++++++++++++- oauthproxy.go | 29 ++++++- oauthproxy_test.go | 120 ++++++++++++++++++++++++++++ pkg/util/util.go | 19 +++++ pkg/util/util_test.go | 26 ++++++ 6 files changed, 263 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6af45d6a..2846d28e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ - [#953](https://github.com/oauth2-proxy/oauth2-proxy/pull/953) Keycloak will now use `--profile-url` if set for the userinfo endpoint instead of `--validate-url`. `--validate-url` will still work for backwards compatibility. +- [#957](https://github.com/oauth2-proxy/oauth2-proxy/pull/957) To use X-Forwarded-{Proto,Host,Uri} on redirect detection, `--reverse-proxy` must be `true`. - [#936](https://github.com/oauth2-proxy/oauth2-proxy/pull/936) `--user-id-claim` option is deprecated and replaced by `--oidc-email-claim` - [#630](https://github.com/oauth2-proxy/oauth2-proxy/pull/630) Gitlab projects needs a Gitlab application with the extra `read_api` enabled - [#849](https://github.com/oauth2-proxy/oauth2-proxy/pull/849) `/oauth2/auth` `allowed_groups` querystring parameter can be paired with the `allowed-groups` configuration option. @@ -59,6 +60,7 @@ ## Changes since v6.1.1 - [#953](https://github.com/oauth2-proxy/oauth2-proxy/pull/953) Migrate Keycloak to EnrichSession & support multiple groups for authorization (@NickMeves) +- [#957](https://github.com/oauth2-proxy/oauth2-proxy/pull/957) Use X-Forwarded-{Proto,Host,Uri} on redirect as last resort (@linuxgemini) - [#630](https://github.com/oauth2-proxy/oauth2-proxy/pull/630) Add support for Gitlab project based authentication (@factorysh) - [#907](https://github.com/oauth2-proxy/oauth2-proxy/pull/907) Introduce alpha configuration option to enable testing of structured configuration (@JoelSpeed) - [#938](https://github.com/oauth2-proxy/oauth2-proxy/pull/938) Cleanup missed provider renaming refactor methods (@NickMeves) diff --git a/docs/docs/configuration/overview.md b/docs/docs/configuration/overview.md index a91c6264..457842e9 100644 --- a/docs/docs/configuration/overview.md +++ b/docs/docs/configuration/overview.md @@ -106,7 +106,7 @@ An example [oauth2-proxy.cfg](https://github.com/oauth2-proxy/oauth2-proxy/blob/ | `--request-logging` | bool | Log requests | true | | `--request-logging-format` | string | Template for request log lines | see [Logging Configuration](#logging-configuration) | | `--resource` | string | The resource that is protected (Azure AD only) | | -| `--reverse-proxy` | bool | are we running behind a reverse proxy, controls whether headers like X-Real-IP are accepted | false | +| `--reverse-proxy` | bool | are we running behind a reverse proxy, controls whether headers like X-Real-IP are accepted and allows X-Forwarded-{Proto,Host,Uri} headers to be used on redirect selection | false | | `--scope` | string | OAuth scope specification | | | `--session-cookie-minimal` | bool | strip OAuth tokens from cookie session stores if they aren't needed (cookie session store only) | false | | `--session-store-type` | string | [Session data storage backend](sessions.md); redis or cookie | cookie | @@ -354,6 +354,73 @@ It is recommended to use `--session-store-type=redis` when expecting large sessi You have to substitute *name* with the actual cookie name you configured via --cookie-name parameter. If you don't set a custom cookie name the variable should be "$upstream_cookie__oauth2_proxy_1" instead of "$upstream_cookie_name_1" and the new cookie-name should be "_oauth2_proxy_1=" instead of "name_1=". +## Configuring for use with the Traefik (v2) `ForwardAuth` middleware + +**This option requires `--reverse-proxy` option to be set.** + +The [Traefik v2 `ForwardAuth` middleware](https://doc.traefik.io/traefik/middlewares/forwardauth/) allows Traefik to authenticate requests via the oauth2-proxy's `/oauth2/auth` endpoint on every request, which only returns a 202 Accepted response or a 401 Unauthorized response without proxying the whole request through. For example, on Dynamic File (YAML) Configuration: + +```yaml +http: + routers: + a-service: + rule: "Host(`a-service.example.com`)" + service: a-service-backend + middlewares: + - oauth-errors + - oauth-auth + tls: + certResolver: default + domains: + - main: "example.com" + sans: + - "*.example.com" + oauth: + rule: "Host(`a-service.example.com`, `oauth.example.com`) && PathPrefix(`/oauth2/`)" + middlewares: + - auth-headers + service: oauth-backend + tls: + certResolver: default + domains: + - main: "example.com" + sans: + - "*.example.com" + + services: + a-service-backend: + loadBalancer: + servers: + - url: http://172.16.0.2:7555 + oauth-backend: + loadBalancer: + servers: + - url: http://172.16.0.1:4180 + + middlewares: + auth-headers: + headers: + sslRedirect: true + stsSeconds: 315360000 + browserXssFilter: true + contentTypeNosniff: true + forceSTSHeader: true + sslHost: example.com + stsIncludeSubdomains: true + stsPreload: true + frameDeny: true + oauth-auth: + forwardAuth: + address: https://oauth.example.com/oauth2/auth + trustForwardHeader: true + oauth-errors: + errors: + status: + - "401-403" + service: oauth-backend + query: "/oauth2/sign_in" +``` + :::note If you set up your OAuth2 provider to rotate your client secret, you can use the `client-secret-file` option to reload the secret when it is updated. ::: diff --git a/oauthproxy.go b/oauthproxy.go index f97af98b..74ed6dc0 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -98,6 +98,7 @@ type OAuthProxy struct { SetAuthorization bool PassAuthorization bool PreferEmailToUser bool + ReverseProxy bool skipAuthPreflight bool skipJwtBearerTokens bool templates *template.Template @@ -200,6 +201,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr UserInfoPath: fmt.Sprintf("%s/userinfo", opts.ProxyPrefix), ProxyPrefix: opts.ProxyPrefix, + ReverseProxy: opts.ReverseProxy, provider: opts.GetProvider(), providerNameOverride: opts.ProviderName, sessionStore: sessionStore, @@ -578,10 +580,18 @@ func (p *OAuthProxy) GetRedirect(req *http.Request) (redirect string, err error) if req.Form.Get("rd") != "" { redirect = req.Form.Get("rd") } + // Quirk: On reverse proxies that doesn't have support for + // "X-Auth-Request-Redirect" header or dynamic header/query string + // manipulation (like Traefik v1 and v2), we can try if the header + // X-Forwarded-Host exists or not. + if redirect == "" && isForwardedRequest(req, p.ReverseProxy) { + redirect = p.getRedirectFromForwardHeaders(req) + } if !p.IsValidRedirect(redirect) { // Use RequestURI to preserve ?query redirect = req.URL.RequestURI() - if strings.HasPrefix(redirect, p.ProxyPrefix) { + + if strings.HasPrefix(redirect, fmt.Sprintf("%s/", p.ProxyPrefix)) { redirect = "/" } } @@ -589,6 +599,17 @@ func (p *OAuthProxy) GetRedirect(req *http.Request) (redirect string, err error) return } +// getRedirectFromForwardHeaders returns the redirect URL based on X-Forwarded-{Proto,Host,Uri} headers +func (p *OAuthProxy) getRedirectFromForwardHeaders(req *http.Request) string { + uri := util.GetRequestURI(req) + + if strings.HasPrefix(uri, fmt.Sprintf("%s/", p.ProxyPrefix)) { + uri = "/" + } + + return fmt.Sprintf("%s://%s%s", util.GetRequestProto(req), util.GetRequestHost(req), uri) +} + // splitHostPort separates host and port. If the port is not valid, it returns // the entire input as host, and it doesn't check the validity of the host. // Unlike net.SplitHostPort, but per RFC 3986, it requires ports to be numeric. @@ -686,6 +707,12 @@ func (p *OAuthProxy) isAllowedRoute(req *http.Request) bool { return false } +// isForwardedRequest is used to check if X-Forwarded-Host header exists or not +func isForwardedRequest(req *http.Request, reverseProxy bool) bool { + isForwarded := req.Host != util.GetRequestHost(req) + return isForwarded && reverseProxy +} + // See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en var noCacheHeaders = map[string]string{ "Expires": time.Unix(0, 0).Format(time.RFC1123), diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 56e046c2..572e1ec9 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -1750,6 +1750,8 @@ func TestRequestSignature(t *testing.T) { func TestGetRedirect(t *testing.T) { opts := baseTestOptions() + opts.WhitelistDomains = append(opts.WhitelistDomains, ".example.com") + opts.WhitelistDomains = append(opts.WhitelistDomains, ".example.com:8443") err := validation.Validate(opts) assert.NoError(t, err) require.NotEmpty(t, opts.ProxyPrefix) @@ -1761,27 +1763,145 @@ func TestGetRedirect(t *testing.T) { tests := []struct { name string url string + headers map[string]string + reverseProxy bool expectedRedirect string }{ { name: "request outside of ProxyPrefix redirects to original URL", url: "/foo/bar", + headers: nil, + reverseProxy: false, expectedRedirect: "/foo/bar", }, { name: "request with query preserves query", url: "/foo?bar", + headers: nil, + reverseProxy: false, expectedRedirect: "/foo?bar", }, { name: "request under ProxyPrefix redirects to root", url: proxy.ProxyPrefix + "/foo/bar", + headers: nil, + reverseProxy: false, expectedRedirect: "/", }, + { + name: "proxied request outside of ProxyPrefix redirects to proxied URL", + url: "https://oauth.example.com/foo/bar", + headers: map[string]string{ + "X-Forwarded-Proto": "https", + "X-Forwarded-Host": "a-service.example.com", + "X-Forwarded-Uri": "/foo/bar", + }, + reverseProxy: true, + expectedRedirect: "https://a-service.example.com/foo/bar", + }, + { + name: "non-proxied request with spoofed proxy headers wouldn't redirect", + url: "https://oauth.example.com/foo?bar", + headers: map[string]string{ + "X-Forwarded-Proto": "https", + "X-Forwarded-Host": "a-service.example.com", + "X-Forwarded-Uri": "/foo/bar", + }, + reverseProxy: false, + expectedRedirect: "/foo?bar", + }, + { + name: "proxied request under ProxyPrefix redirects to root", + url: "https://oauth.example.com" + proxy.ProxyPrefix + "/foo/bar", + headers: map[string]string{ + "X-Forwarded-Proto": "https", + "X-Forwarded-Host": "a-service.example.com", + "X-Forwarded-Uri": proxy.ProxyPrefix + "/foo/bar", + }, + reverseProxy: true, + expectedRedirect: "https://a-service.example.com/", + }, + { + name: "proxied request with port under ProxyPrefix redirects to root", + url: "https://oauth.example.com" + proxy.ProxyPrefix + "/foo/bar", + headers: map[string]string{ + "X-Forwarded-Proto": "https", + "X-Forwarded-Host": "a-service.example.com:8443", + "X-Forwarded-Uri": proxy.ProxyPrefix + "/foo/bar", + }, + reverseProxy: true, + expectedRedirect: "https://a-service.example.com:8443/", + }, + { + name: "proxied request with missing uri header would still redirect to desired redirect", + url: "https://oauth.example.com/foo?bar", + headers: map[string]string{ + "X-Forwarded-Proto": "https", + "X-Forwarded-Host": "a-service.example.com", + }, + reverseProxy: true, + expectedRedirect: "https://a-service.example.com/foo?bar", + }, + { + name: "request with headers proxy not being set (and reverse proxy enabled) would still redirect to desired redirect", + url: "https://oauth.example.com/foo?bar", + headers: nil, + reverseProxy: true, + expectedRedirect: "/foo?bar", + }, + { + name: "proxied request with X-Auth-Request-Redirect being set outside of ProxyPrefix redirects to proxied URL", + url: "https://oauth.example.com/foo/bar", + headers: map[string]string{ + "X-Auth-Request-Redirect": "https://a-service.example.com/foo/bar", + "X-Forwarded-Proto": "", + "X-Forwarded-Host": "", + "X-Forwarded-Uri": "", + }, + reverseProxy: true, + expectedRedirect: "https://a-service.example.com/foo/bar", + }, + { + name: "proxied request with rd query string redirects to proxied URL", + url: "https://oauth.example.com/foo/bar?rd=https%3A%2F%2Fa%2Dservice%2Eexample%2Ecom%2Ffoo%2Fbar", + headers: nil, + reverseProxy: false, + expectedRedirect: "https://a-service.example.com/foo/bar", + }, + { + name: "proxied request with rd query string and all headers set (and reverse proxy not enabled) redirects to proxied URL on rd query string", + url: "https://oauth.example.com/foo/bar?rd=https%3A%2F%2Fa%2Dservice%2Eexample%2Ecom%2Ffoo%2Fjazz", + headers: map[string]string{ + "X-Auth-Request-Redirect": "https://a-service.example.com/foo/baz", + "X-Forwarded-Proto": "http", + "X-Forwarded-Host": "another-service.example.com", + "X-Forwarded-Uri": "/seasons/greetings", + }, + reverseProxy: false, + expectedRedirect: "https://a-service.example.com/foo/jazz", + }, + { + name: "proxied request with rd query string and some headers set redirects to proxied URL on rd query string", + url: "https://oauth.example.com/foo/bar?rd=https%3A%2F%2Fa%2Dservice%2Eexample%2Ecom%2Ffoo%2Fbaz", + headers: map[string]string{ + "X-Auth-Request-Redirect": "", + "X-Forwarded-Proto": "https", + "X-Forwarded-Host": "another-service.example.com", + "X-Forwarded-Uri": "/seasons/greetings", + }, + reverseProxy: true, + expectedRedirect: "https://a-service.example.com/foo/baz", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { req, _ := http.NewRequest("GET", tt.url, nil) + for header, value := range tt.headers { + if value != "" { + req.Header.Add(header, value) + } + } + proxy.ReverseProxy = tt.reverseProxy redirect, err := proxy.GetRedirect(req) assert.NoError(t, err) diff --git a/pkg/util/util.go b/pkg/util/util.go index b39c1032..4eeabbf7 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -25,6 +25,15 @@ func GetCertPool(paths []string) (*x509.CertPool, error) { return pool, nil } +// GetRequestProto return the request host header or X-Forwarded-Proto if present +func GetRequestProto(req *http.Request) string { + proto := req.Header.Get("X-Forwarded-Proto") + if proto == "" { + proto = req.URL.Scheme + } + return proto +} + // GetRequestHost return the request host header or X-Forwarded-Host if present func GetRequestHost(req *http.Request) string { host := req.Header.Get("X-Forwarded-Host") @@ -33,3 +42,13 @@ func GetRequestHost(req *http.Request) string { } return host } + +// GetRequestURI return the request host header or X-Forwarded-Uri if present +func GetRequestURI(req *http.Request) string { + uri := req.Header.Get("X-Forwarded-Uri") + if uri == "" { + // Use RequestURI to preserve ?query + uri = req.URL.RequestURI() + } + return uri +} diff --git a/pkg/util/util_test.go b/pkg/util/util_test.go index c1e3d688..d032025e 100644 --- a/pkg/util/util_test.go +++ b/pkg/util/util_test.go @@ -110,3 +110,29 @@ func TestGetRequestHost(t *testing.T) { extHost := GetRequestHost(proxyReq) g.Expect(extHost).To(Equal("external.example.com")) } + +func TestGetRequestProto(t *testing.T) { + g := NewWithT(t) + + req := httptest.NewRequest("GET", "https://example.com", nil) + proto := GetRequestProto(req) + g.Expect(proto).To(Equal("https")) + + proxyReq := httptest.NewRequest("GET", "https://internal.example.com", nil) + proxyReq.Header.Add("X-Forwarded-Proto", "http") + extProto := GetRequestProto(proxyReq) + g.Expect(extProto).To(Equal("http")) +} + +func TestGetRequestURI(t *testing.T) { + g := NewWithT(t) + + req := httptest.NewRequest("GET", "https://example.com/ping", nil) + uri := GetRequestURI(req) + g.Expect(uri).To(Equal("/ping")) + + proxyReq := httptest.NewRequest("GET", "http://internal.example.com/bong", nil) + proxyReq.Header.Add("X-Forwarded-Uri", "/ping") + extURI := GetRequestURI(proxyReq) + g.Expect(extURI).To(Equal("/ping")) +} From 597ffeb1217058afd71182e8cec1de670f8b6998 Mon Sep 17 00:00:00 2001 From: Ilia Pertsev Date: Tue, 5 Jan 2021 04:21:17 +0300 Subject: [PATCH 19/28] Fix joined cookie name for those containing underline in the suffix (#970) * properly handle splitted cookies with names ending with _ * test update * provide cookieName into joinCookies instead of processing the suffix * changelog update * test update --- CHANGELOG.md | 1 + pkg/sessions/cookie/session_store.go | 7 ++-- pkg/sessions/cookie/session_store_test.go | 51 ++++++++++++++++++++++- 3 files changed, 54 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2846d28e..2272fb47 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -59,6 +59,7 @@ ## Changes since v6.1.1 +- [#970](https://github.com/oauth2-proxy/oauth2-proxy/pull/970) Fix joined cookie name for those containing underline in the suffix (@peppered) - [#953](https://github.com/oauth2-proxy/oauth2-proxy/pull/953) Migrate Keycloak to EnrichSession & support multiple groups for authorization (@NickMeves) - [#957](https://github.com/oauth2-proxy/oauth2-proxy/pull/957) Use X-Forwarded-{Proto,Host,Uri} on redirect as last resort (@linuxgemini) - [#630](https://github.com/oauth2-proxy/oauth2-proxy/pull/630) Add support for Gitlab project based authentication (@factorysh) diff --git a/pkg/sessions/cookie/session_store.go b/pkg/sessions/cookie/session_store.go index 4a546cfc..ce51ed07 100644 --- a/pkg/sessions/cookie/session_store.go +++ b/pkg/sessions/cookie/session_store.go @@ -5,7 +5,6 @@ import ( "fmt" "net/http" "regexp" - "strings" "time" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" @@ -220,12 +219,12 @@ func loadCookie(req *http.Request, cookieName string) (*http.Cookie, error) { if len(cookies) == 0 { return nil, fmt.Errorf("could not find cookie %s", cookieName) } - return joinCookies(cookies) + return joinCookies(cookies, cookieName) } // joinCookies takes a slice of cookies from the request and reconstructs the // full session cookie -func joinCookies(cookies []*http.Cookie) (*http.Cookie, error) { +func joinCookies(cookies []*http.Cookie, cookieName string) (*http.Cookie, error) { if len(cookies) == 0 { return nil, fmt.Errorf("list of cookies must be > 0") } @@ -236,7 +235,7 @@ func joinCookies(cookies []*http.Cookie) (*http.Cookie, error) { for i := 1; i < len(cookies); i++ { c.Value += cookies[i].Value } - c.Name = strings.TrimRight(c.Name, "_0") + c.Name = cookieName return c, nil } diff --git a/pkg/sessions/cookie/session_store_test.go b/pkg/sessions/cookie/session_store_test.go index 5ef9eff1..bf9dda14 100644 --- a/pkg/sessions/cookie/session_store_test.go +++ b/pkg/sessions/cookie/session_store_test.go @@ -154,9 +154,58 @@ func Test_splitCookie_joinCookies(t *testing.T) { Value: value, } splitCookies := splitCookie(cookie) - joinedCookie, err := joinCookies(splitCookies) + joinedCookie, err := joinCookies(splitCookies, cookie.Name) assert.NoError(t, err) assert.Equal(t, *cookie, *joinedCookie) }) } } + +func Test_joinCookies_withUnderlineSuffix(t *testing.T) { + testCases := map[string]struct { + CookieName string + SplitOrder []int + }{ + "Ascending order split with \"_\" suffix": { + CookieName: "_cookie_name_", + SplitOrder: []int{0, 1, 2, 3, 4}, + }, + "Descending order split with \"_\" suffix": { + CookieName: "_cookie_name_", + SplitOrder: []int{4, 3, 2, 1, 0}, + }, + "Arbitrary order split with \"_\" suffix": { + CookieName: "_cookie_name_", + SplitOrder: []int{3, 1, 2, 0, 4}, + }, + "Arbitrary order split with \"_0\" suffix": { + CookieName: "_cookie_name_0", + SplitOrder: []int{1, 3, 0, 2, 4}, + }, + "Arbitrary order split with \"_1\" suffix": { + CookieName: "_cookie_name_1", + SplitOrder: []int{4, 1, 3, 0, 2}, + }, + "Arbitrary order split with \"__\" suffix": { + CookieName: "_cookie_name__", + SplitOrder: []int{1, 0, 4, 3, 2}, + }, + } + + for testName, testCase := range testCases { + t.Run(testName, func(t *testing.T) { + cookieName := testCase.CookieName + var splitCookies []*http.Cookie + for _, splitSuffix := range testCase.SplitOrder { + cookie := &http.Cookie{ + Name: splitCookieName(cookieName, splitSuffix), + Value: strings.Repeat("v", 1000), + } + splitCookies = append(splitCookies, cookie) + } + joinedCookie, err := joinCookies(splitCookies, cookieName) + assert.NoError(t, err) + assert.Equal(t, cookieName, joinedCookie.Name) + }) + } +} From d08b9b7cc49c0aed7422aa76f7e5a1218587d559 Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Sun, 10 Jan 2021 10:56:01 -0800 Subject: [PATCH 20/28] Add NickMeves to MAINTAINERS --- MAINTAINERS | 1 + 1 file changed, 1 insertion(+) diff --git a/MAINTAINERS b/MAINTAINERS index 1642e741..784eef4d 100644 --- a/MAINTAINERS +++ b/MAINTAINERS @@ -1,2 +1,3 @@ Joel Speed (@JoelSpeed) Henry Jenkins (@steakunderscore) +Nick Meves (@NickMeves) From 81bf1ef8ceb8be4b4a86007a042157ac0eeb2a28 Mon Sep 17 00:00:00 2001 From: Nikolai Prokoschenko Date: Tue, 12 Jan 2021 16:40:14 +0100 Subject: [PATCH 21/28] Adapt isAjax to support mimetype lists Fixes #988 --- CHANGELOG.md | 1 + oauthproxy.go | 14 +++++++++++--- oauthproxy_test.go | 7 +++++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2272fb47..fc41907b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -100,6 +100,7 @@ - [#750](https://github.com/oauth2-proxy/oauth2-proxy/pull/750) ci: Migrate to Github Actions (@shinebayar-g) - [#829](https://github.com/oauth2-proxy/oauth2-proxy/pull/820) Rename test directory to testdata (@johejo) - [#819](https://github.com/oauth2-proxy/oauth2-proxy/pull/819) Improve CI (@johejo) +- [#989](https://github.com/oauth2-proxy/oauth2-proxy/pull/989) Adapt isAjax to support mimetype lists (@rassie) # v6.1.1 diff --git a/oauthproxy.go b/oauthproxy.go index 74ed6dc0..cfba6934 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -1111,9 +1111,17 @@ func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, req *http.Req func isAjax(req *http.Request) bool { acceptValues := req.Header.Values("Accept") const ajaxReq = applicationJSON - for _, v := range acceptValues { - if v == ajaxReq { - return true + // Iterate over multiple Accept headers, i.e. + // Accept: application/json + // Accept: text/plain + for _, mimeTypes := range acceptValues { + // Iterate over multiple mimetypes in a single header, i.e. + // Accept: application/json, text/plain, */* + for _, mimeType := range strings.Split(mimeTypes, ",") { + mimeType = strings.TrimSpace(mimeType) + if mimeType == ajaxReq { + return true + } } } return false diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 572e1ec9..52ffa2b3 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -1970,6 +1970,13 @@ func TestAjaxUnauthorizedRequest2(t *testing.T) { testAjaxUnauthorizedRequest(t, header) } +func TestAjaxUnauthorizedRequestAccept1(t *testing.T) { + header := make(http.Header) + header.Add("Accept", "application/json, text/plain, */*") + + testAjaxUnauthorizedRequest(t, header) +} + func TestAjaxForbiddendRequest(t *testing.T) { test, err := newAjaxRequestTest() if err != nil { From e50e6ed373f757cc7ed581e94b1f42e0509026db Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Sat, 16 Jan 2021 17:34:53 +0000 Subject: [PATCH 22/28] Add Security Policy --- CHANGELOG.md | 1 + SECURITY.md | 3 ++ docs/docs/community/security.md | 49 +++++++++++++++++++ docs/sidebars.js | 6 +++ .../version-6.1.x/community/security.md | 49 +++++++++++++++++++ .../version-6.1.x-sidebars.json | 11 +++++ 6 files changed, 119 insertions(+) create mode 100644 SECURITY.md create mode 100644 docs/docs/community/security.md create mode 100644 docs/versioned_docs/version-6.1.x/community/security.md diff --git a/CHANGELOG.md b/CHANGELOG.md index fc41907b..0e38f0cd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -59,6 +59,7 @@ ## Changes since v6.1.1 +- [#995](https://github.com/oauth2-proxy/oauth2-proxy/pull/995) Add Security Policy (@JoelSpeed) - [#970](https://github.com/oauth2-proxy/oauth2-proxy/pull/970) Fix joined cookie name for those containing underline in the suffix (@peppered) - [#953](https://github.com/oauth2-proxy/oauth2-proxy/pull/953) Migrate Keycloak to EnrichSession & support multiple groups for authorization (@NickMeves) - [#957](https://github.com/oauth2-proxy/oauth2-proxy/pull/957) Use X-Forwarded-{Proto,Host,Uri} on redirect as last resort (@linuxgemini) diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 00000000..0d6d27e8 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,3 @@ +# Security Disclosures + +Please see [our community docs](https://oauth2-proxy.github.io/oauth2-proxy/docs/community/security) for our security policy. diff --git a/docs/docs/community/security.md b/docs/docs/community/security.md new file mode 100644 index 00000000..c24b57d9 --- /dev/null +++ b/docs/docs/community/security.md @@ -0,0 +1,49 @@ +--- +id: security +title: Security +--- + +:::note +OAuth2 Proxy is a community project. +Maintainers do not work on this project full time, and as such, +while we endeavour to respond to disclosures as quickly as possible, +this may take longer than in projects with corporate sponsorship. +::: + +## Security Disclosures + +:::important +If you believe you have found a vulnerability within OAuth2 Proxy or any of its +dependencies, please do NOT open an issue or PR on GitHub, please do NOT post +any details publicly. +::: + +Security disclosures MUST be done in private. +If you have found an issue that you would like to bring to the attention of the +maintenance team for OAuth2 Proxy, please compose an email and send it to the +list of maintainers in our [MAINTAINERS](https://github.com/oauth2-proxy/oauth2-proxy/blob/master/MAINTAINERS) file. + +Please include as much detail as possible. +Ideally, your disclosure should include: +- A reproducible case that can be used to demonstrate the exploit +- How you discovered this vulnerability +- A potential fix for the issue (if you have thought of one) +- Versions affected (if not present in master) +- Your GitHub ID + +### How will we respond to disclosures? + +We use [GitHub Security Advisories](https://docs.github.com/en/github/managing-security-vulnerabilities/about-github-security-advisories) +to privately discuss fixes for disclosed vulnerabilities. +If you include a GitHub ID with your disclosure we will add you as a collaborator +for the advisory so that you can join the discussion and validate any fixes +we may propose. + +For minor issues and previously disclosed vulnerabilities (typically for +dependencies), we may use regular PRs for fixes and forego the security advisory. + +Once a fix has been agreed upon, we will merge the fix and create a new release. +If we have multiple security issues in flight simultaneously, we may delay +merging fixes until all patches are ready. +We may also backport the fix to previous releases, +but this will be at the discretion of the maintainers. diff --git a/docs/sidebars.js b/docs/sidebars.js index f25dc619..4e5bd9a4 100644 --- a/docs/sidebars.js +++ b/docs/sidebars.js @@ -20,5 +20,11 @@ module.exports = { collapsed: false, items: ['features/endpoints', 'features/request_signatures'], }, + { + type: 'category', + label: 'Community', + collapsed: false, + items: ['community/security'], + }, ], }; diff --git a/docs/versioned_docs/version-6.1.x/community/security.md b/docs/versioned_docs/version-6.1.x/community/security.md new file mode 100644 index 00000000..9c406cf8 --- /dev/null +++ b/docs/versioned_docs/version-6.1.x/community/security.md @@ -0,0 +1,49 @@ +--- +id: security +title: Security +--- + +:::note +OAuth2 Proxy is a community project. +Maintainers do not work on this project full time, and as such, +while we endeavour to respond to disclosures as quickly as possible, +this may take longer than in projects with corporate sponsorship. +::: + +## Security Disclosures + +:::important +If you believe you have found a vulnerability within OAuth2 Proxy or any of its +dependencies, please do NOT open an issue or PR on GitHub, please do NOT post any +details publicly. +::: + +Security disclosures MUST be done in private. +If you have found an issue that you would like to bring to the attention of the +maintenance team for OAuth2 Proxy, please compose an email and send it to the +list of maintainers in our [MAINTAINERS](https://github.com/oauth2-proxy/oauth2-proxy/blob/master/MAINTAINERS) file. + +Please include as much detail as possible. +Ideally, your disclosure should include: +- A reproducible case that can be used to demonstrate the exploit +- How you discovered this vulnerability +- A potential fix for the issue (if you have thought of one) +- Versions affected (if not present in master) +- Your GitHub ID + +### How will we respond to disclosures? + +We use [GitHub Security Advisories](https://docs.github.com/en/github/managing-security-vulnerabilities/about-github-security-advisories) +to privately discuss fixes for disclosed vulnerabilities. +If you include a GitHub ID with your disclosure we will add you as a collaborator +for the advisory so that you can join the discussion and validate any fixes +we may propose. + +For minor issues and previously disclosed vulnerabilities (typically for +dependencies), we may use regular PRs for fixes and forego the security advisory. + +Once a fix has been agreed upon, we will merge the fix and create a new release. +If we have multiple security issues in flight simultaneously, we may delay +merging fixes until all patches are ready. +We may also backport the fix to previous releases, +but this will be at the discretion of the maintainers. diff --git a/docs/versioned_sidebars/version-6.1.x-sidebars.json b/docs/versioned_sidebars/version-6.1.x-sidebars.json index d552f4a3..1e173ada 100644 --- a/docs/versioned_sidebars/version-6.1.x-sidebars.json +++ b/docs/versioned_sidebars/version-6.1.x-sidebars.json @@ -45,6 +45,17 @@ "id": "version-6.1.x/features/request_signatures" } ] + }, + { + "collapsed": false, + "type": "category", + "label": "Community", + "items": [ + { + "type": "doc", + "id": "version-6.1.x/community/security" + } + ] } ] } From b625de94904f33130714d92b3a650bd0cef569bf Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Wed, 23 Dec 2020 17:42:02 -0800 Subject: [PATCH 23/28] Track the ReverseProxy option in the request Scope This allows for proper handling of reverse proxy based headers throughout the lifecycle of a request. --- oauthproxy.go | 2 +- pkg/apis/middleware/scope.go | 4 ++++ pkg/middleware/scope.go | 26 +++++++++++++------------- pkg/util/util.go | 14 +++++++++++--- 4 files changed, 29 insertions(+), 17 deletions(-) diff --git a/oauthproxy.go b/oauthproxy.go index cfba6934..b51b9bea 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -231,7 +231,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr // the OAuth2 Proxy authentication logic kicks in. // For example forcing HTTPS or health checks. func buildPreAuthChain(opts *options.Options) (alice.Chain, error) { - chain := alice.New(middleware.NewScope()) + chain := alice.New(middleware.NewScope(opts)) if opts.ForceHTTPS { _, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress) diff --git a/pkg/apis/middleware/scope.go b/pkg/apis/middleware/scope.go index 37f6f336..cb6fe4b8 100644 --- a/pkg/apis/middleware/scope.go +++ b/pkg/apis/middleware/scope.go @@ -8,6 +8,10 @@ import ( // The RequestScope is used to pass information between different middlewares // within the chain. type RequestScope struct { + // ReverseProxy tracks whether OAuth2-Proxy is operating in reverse proxy + // mode and if request `X-Forwarded-*` headers should be trusted + ReverseProxy bool + // Session details the authenticated users information (if it exists). Session *sessions.SessionState diff --git a/pkg/middleware/scope.go b/pkg/middleware/scope.go index 88719310..6485cc4f 100644 --- a/pkg/middleware/scope.go +++ b/pkg/middleware/scope.go @@ -6,26 +6,26 @@ import ( "github.com/justinas/alice" middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" ) type scopeKey string -// requestScopeKey uses a typed string to reduce likelihood of clasing +// requestScopeKey uses a typed string to reduce likelihood of clashing // with other context keys const requestScopeKey scopeKey = "request-scope" -func NewScope() alice.Constructor { - return addScope -} - -// addScope injects a new request scope into the request context. -func addScope(next http.Handler) http.Handler { - return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - scope := &middlewareapi.RequestScope{} - contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) - requestWithScope := req.WithContext(contextWithScope) - next.ServeHTTP(rw, requestWithScope) - }) +func NewScope(opts *options.Options) alice.Constructor { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + scope := &middlewareapi.RequestScope{ + ReverseProxy: opts.ReverseProxy, + } + contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) + requestWithScope := req.WithContext(contextWithScope) + next.ServeHTTP(rw, requestWithScope) + }) + } } // GetRequestScope returns the current request scope from the given request diff --git a/pkg/util/util.go b/pkg/util/util.go index 4eeabbf7..452e14f1 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -5,6 +5,8 @@ import ( "fmt" "io/ioutil" "net/http" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" ) func GetCertPool(paths []string) (*x509.CertPool, error) { @@ -28,16 +30,17 @@ func GetCertPool(paths []string) (*x509.CertPool, error) { // GetRequestProto return the request host header or X-Forwarded-Proto if present func GetRequestProto(req *http.Request) string { proto := req.Header.Get("X-Forwarded-Proto") - if proto == "" { + if !isProxied(req) || proto == "" { proto = req.URL.Scheme } return proto } // GetRequestHost return the request host header or X-Forwarded-Host if present +// and reverse proxy mode is enabled. func GetRequestHost(req *http.Request) string { host := req.Header.Get("X-Forwarded-Host") - if host == "" { + if !isProxied(req) || host == "" { host = req.Host } return host @@ -46,9 +49,14 @@ func GetRequestHost(req *http.Request) string { // GetRequestURI return the request host header or X-Forwarded-Uri if present func GetRequestURI(req *http.Request) string { uri := req.Header.Get("X-Forwarded-Uri") - if uri == "" { + if !isProxied(req) || uri == "" { // Use RequestURI to preserve ?query uri = req.URL.RequestURI() } return uri } + +func isProxied(req *http.Request) bool { + scope := middleware.GetRequestScope(req) + return scope.ReverseProxy +} From 6fb3274ca3649d8e5b263e0742879002608f1455 Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Sat, 2 Jan 2021 13:16:01 -0800 Subject: [PATCH 24/28] Refactor organization of scope aware request utils Reorganized the structure of the Request Utils due to their widespread use resulting in circular imports issues (mostly because of middleware & logger). --- pkg/apis/middleware/middleware_suite_test.go | 19 +++ pkg/apis/middleware/scope.go | 25 ++++ pkg/apis/middleware/scope_test.go | 56 ++++++++ pkg/cookies/cookies.go | 8 +- pkg/logger/logger.go | 6 +- pkg/middleware/basic_session.go | 3 +- pkg/middleware/basic_session_test.go | 6 +- pkg/middleware/headers.go | 5 +- pkg/middleware/headers_test.go | 7 +- pkg/middleware/jwt_session.go | 2 +- pkg/middleware/jwt_session_test.go | 5 +- pkg/middleware/scope.go | 27 +--- pkg/middleware/scope_test.go | 89 +++++-------- pkg/middleware/stored_session.go | 3 +- pkg/middleware/stored_session_test.go | 5 +- pkg/requests/util/util.go | 48 +++++++ pkg/requests/util/util_suite_test.go | 19 +++ pkg/requests/util/util_test.go | 131 +++++++++++++++++++ pkg/util/util.go | 37 ------ pkg/util/util_test.go | 41 ------ 20 files changed, 357 insertions(+), 185 deletions(-) create mode 100644 pkg/apis/middleware/middleware_suite_test.go create mode 100644 pkg/apis/middleware/scope_test.go create mode 100644 pkg/requests/util/util.go create mode 100644 pkg/requests/util/util_suite_test.go create mode 100644 pkg/requests/util/util_test.go diff --git a/pkg/apis/middleware/middleware_suite_test.go b/pkg/apis/middleware/middleware_suite_test.go new file mode 100644 index 00000000..f2f48cfd --- /dev/null +++ b/pkg/apis/middleware/middleware_suite_test.go @@ -0,0 +1,19 @@ +package middleware_test + +import ( + "testing" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +// TestMiddlewareSuite and related tests are in a *_test package +// to prevent circular imports with the `logger` package which uses +// this functionality +func TestMiddlewareSuite(t *testing.T) { + logger.SetOutput(GinkgoWriter) + + RegisterFailHandler(Fail) + RunSpecs(t, "Middleware API") +} diff --git a/pkg/apis/middleware/scope.go b/pkg/apis/middleware/scope.go index cb6fe4b8..c54a33d1 100644 --- a/pkg/apis/middleware/scope.go +++ b/pkg/apis/middleware/scope.go @@ -1,9 +1,18 @@ package middleware import ( + "context" + "net/http" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" ) +type scopeKey string + +// RequestScopeKey uses a typed string to reduce likelihood of clashing +// with other context keys +const RequestScopeKey scopeKey = "request-scope" + // RequestScope contains information regarding the request that is being made. // The RequestScope is used to pass information between different middlewares // within the chain. @@ -26,3 +35,19 @@ type RequestScope struct { // it was loaded or not. SessionRevalidated bool } + +// GetRequestScope returns the current request scope from the given request +func GetRequestScope(req *http.Request) *RequestScope { + scope := req.Context().Value(RequestScopeKey) + if scope == nil { + return nil + } + + return scope.(*RequestScope) +} + +// AddRequestScope adds a RequestScope to a request +func AddRequestScope(req *http.Request, scope *RequestScope) *http.Request { + ctx := context.WithValue(req.Context(), RequestScopeKey, scope) + return req.WithContext(ctx) +} diff --git a/pkg/apis/middleware/scope_test.go b/pkg/apis/middleware/scope_test.go new file mode 100644 index 00000000..355365bf --- /dev/null +++ b/pkg/apis/middleware/scope_test.go @@ -0,0 +1,56 @@ +package middleware_test + +import ( + "net/http" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Scope Suite", func() { + Context("GetRequestScope", func() { + var request *http.Request + + BeforeEach(func() { + var err error + request, err = http.NewRequest("", "http://127.0.0.1/", nil) + Expect(err).ToNot(HaveOccurred()) + }) + + Context("with a scope", func() { + var scope *middleware.RequestScope + + BeforeEach(func() { + scope = &middleware.RequestScope{} + request = middleware.AddRequestScope(request, scope) + }) + + It("returns the scope", func() { + s := middleware.GetRequestScope(request) + Expect(s).ToNot(BeNil()) + Expect(s).To(Equal(scope)) + }) + + Context("if the scope is then modified", func() { + BeforeEach(func() { + Expect(scope.SaveSession).To(BeFalse()) + scope.SaveSession = true + }) + + It("returns the updated session", func() { + s := middleware.GetRequestScope(request) + Expect(s).ToNot(BeNil()) + Expect(s).To(Equal(scope)) + Expect(s.SaveSession).To(BeTrue()) + }) + }) + }) + + Context("without a scope", func() { + It("returns nil", func() { + Expect(middleware.GetRequestScope(request)).To(BeNil()) + }) + }) + }) +}) diff --git a/pkg/cookies/cookies.go b/pkg/cookies/cookies.go index 9b6dc03d..c590de38 100644 --- a/pkg/cookies/cookies.go +++ b/pkg/cookies/cookies.go @@ -9,14 +9,14 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util" + requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util" ) // MakeCookie constructs a cookie from the given parameters, // discovering the domain from the request if not specified. func MakeCookie(req *http.Request, name string, value string, path string, domain string, httpOnly bool, secure bool, expiration time.Duration, now time.Time, sameSite http.SameSite) *http.Cookie { if domain != "" { - host := util.GetRequestHost(req) + host := requestutil.GetRequestHost(req) if h, _, err := net.SplitHostPort(host); err == nil { host = h } @@ -48,7 +48,7 @@ func MakeCookieFromOptions(req *http.Request, name string, value string, cookieO // If nothing matches, create the cookie with the shortest domain defaultDomain := "" if len(cookieOpts.Domains) > 0 { - logger.Errorf("Warning: request host %q did not match any of the specific cookie domains of %q", util.GetRequestHost(req), strings.Join(cookieOpts.Domains, ",")) + logger.Errorf("Warning: request host %q did not match any of the specific cookie domains of %q", requestutil.GetRequestHost(req), strings.Join(cookieOpts.Domains, ",")) defaultDomain = cookieOpts.Domains[len(cookieOpts.Domains)-1] } return MakeCookie(req, name, value, cookieOpts.Path, defaultDomain, cookieOpts.HTTPOnly, cookieOpts.Secure, expiration, now, ParseSameSite(cookieOpts.SameSite)) @@ -57,7 +57,7 @@ func MakeCookieFromOptions(req *http.Request, name string, value string, cookieO // GetCookieDomain returns the correct cookie domain given a list of domains // by checking the X-Fowarded-Host and host header of an an http request func GetCookieDomain(req *http.Request, cookieDomains []string) string { - host := util.GetRequestHost(req) + host := requestutil.GetRequestHost(req) for _, domain := range cookieDomains { if strings.HasSuffix(host, domain) { return domain diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index 23696765..86ad720e 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -12,7 +12,7 @@ import ( "text/template" "time" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util" + requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util" ) // AuthStatus defines the different types of auth logging that occur @@ -197,7 +197,7 @@ func (l *Logger) PrintAuthf(username string, req *http.Request, status AuthStatu err := l.authTemplate.Execute(l.writer, authLogMessageData{ Client: client, - Host: util.GetRequestHost(req), + Host: requestutil.GetRequestHost(req), Protocol: req.Proto, RequestMethod: req.Method, Timestamp: FormatTimestamp(now), @@ -251,7 +251,7 @@ func (l *Logger) PrintReq(username, upstream string, req *http.Request, url url. err := l.reqTemplate.Execute(l.writer, reqLogMessageData{ Client: client, - Host: util.GetRequestHost(req), + Host: requestutil.GetRequestHost(req), Protocol: req.Proto, RequestDuration: fmt.Sprintf("%0.3f", duration), RequestMethod: req.Method, diff --git a/pkg/middleware/basic_session.go b/pkg/middleware/basic_session.go index 5a7b77f9..7de1bf2b 100644 --- a/pkg/middleware/basic_session.go +++ b/pkg/middleware/basic_session.go @@ -5,6 +5,7 @@ import ( "net/http" "github.com/justinas/alice" + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" @@ -23,7 +24,7 @@ func NewBasicAuthSessionLoader(validator basic.Validator) alice.Constructor { // If a session was loaded by a previous handler, it will not be replaced. func loadBasicAuthSession(validator basic.Validator, next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - scope := GetRequestScope(req) + scope := middlewareapi.GetRequestScope(req) // If scope is nil, this will panic. // A scope should always be injected before this handler is called. if scope.Session != nil { diff --git a/pkg/middleware/basic_session_test.go b/pkg/middleware/basic_session_test.go index 35e4f804..14c49c43 100644 --- a/pkg/middleware/basic_session_test.go +++ b/pkg/middleware/basic_session_test.go @@ -1,7 +1,6 @@ package middleware import ( - "context" "fmt" "net/http" "net/http/httptest" @@ -40,8 +39,7 @@ var _ = Describe("Basic Auth Session Suite", func() { // Set up the request with the authorization header and a request scope req := httptest.NewRequest("", "/", nil) req.Header.Set("Authorization", in.authorizationHeader) - contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) - req = req.WithContext(contextWithScope) + req = middlewareapi.AddRequestScope(req, scope) rw := httptest.NewRecorder() @@ -57,7 +55,7 @@ var _ = Describe("Basic Auth Session Suite", func() { // from the scope var gotSession *sessionsapi.SessionState handler := NewBasicAuthSessionLoader(validator)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session + gotSession = middlewareapi.GetRequestScope(r).Session })) handler.ServeHTTP(rw, req) diff --git a/pkg/middleware/headers.go b/pkg/middleware/headers.go index 6786c2eb..b79b547b 100644 --- a/pkg/middleware/headers.go +++ b/pkg/middleware/headers.go @@ -5,6 +5,7 @@ import ( "net/http" "github.com/justinas/alice" + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/header" ) @@ -61,7 +62,7 @@ func newRequestHeaderInjector(headers []options.Header) (alice.Constructor, erro func injectRequestHeaders(injector header.Injector, next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - scope := GetRequestScope(req) + scope := middlewareapi.GetRequestScope(req) // If scope is nil, this will panic. // A scope should always be injected before this handler is called. @@ -92,7 +93,7 @@ func newResponseHeaderInjector(headers []options.Header) (alice.Constructor, err func injectResponseHeaders(injector header.Injector, next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - scope := GetRequestScope(req) + scope := middlewareapi.GetRequestScope(req) // If scope is nil, this will panic. // A scope should always be injected before this handler is called. diff --git a/pkg/middleware/headers_test.go b/pkg/middleware/headers_test.go index 15006b1d..a9c6d73e 100644 --- a/pkg/middleware/headers_test.go +++ b/pkg/middleware/headers_test.go @@ -1,7 +1,6 @@ package middleware import ( - "context" "encoding/base64" "net/http" "net/http/httptest" @@ -31,8 +30,7 @@ var _ = Describe("Headers Suite", func() { // Set up the request with a request scope req := httptest.NewRequest("", "/", nil) - contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) - req = req.WithContext(contextWithScope) + req = middlewareapi.AddRequestScope(req, scope) req.Header = in.initialHeaders.Clone() rw := httptest.NewRecorder() @@ -218,8 +216,7 @@ var _ = Describe("Headers Suite", func() { // Set up the request with a request scope req := httptest.NewRequest("", "/", nil) - contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) - req = req.WithContext(contextWithScope) + req = middlewareapi.AddRequestScope(req, scope) rw := httptest.NewRecorder() for key, values := range in.initialHeaders { diff --git a/pkg/middleware/jwt_session.go b/pkg/middleware/jwt_session.go index 0510c72a..78ef5400 100644 --- a/pkg/middleware/jwt_session.go +++ b/pkg/middleware/jwt_session.go @@ -37,7 +37,7 @@ type jwtSessionLoader struct { // If a session was loaded by a previous handler, it will not be replaced. func (j *jwtSessionLoader) loadSession(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - scope := GetRequestScope(req) + scope := middlewareapi.GetRequestScope(req) // If scope is nil, this will panic. // A scope should always be injected before this handler is called. if scope.Session != nil { diff --git a/pkg/middleware/jwt_session_test.go b/pkg/middleware/jwt_session_test.go index cd34c5ad..7786d00a 100644 --- a/pkg/middleware/jwt_session_test.go +++ b/pkg/middleware/jwt_session_test.go @@ -103,8 +103,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` // Set up the request with the authorization header and a request scope req := httptest.NewRequest("", "/", nil) req.Header.Set("Authorization", in.authorizationHeader) - contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) - req = req.WithContext(contextWithScope) + req = middlewareapi.AddRequestScope(req, scope) rw := httptest.NewRecorder() @@ -116,7 +115,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` // from the scope var gotSession *sessionsapi.SessionState handler := NewJwtSessionLoader(sessionLoaders)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session + gotSession = middlewareapi.GetRequestScope(r).Session })) handler.ServeHTTP(rw, req) diff --git a/pkg/middleware/scope.go b/pkg/middleware/scope.go index 6485cc4f..9218faa0 100644 --- a/pkg/middleware/scope.go +++ b/pkg/middleware/scope.go @@ -1,39 +1,20 @@ package middleware import ( - "context" "net/http" "github.com/justinas/alice" middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" ) -type scopeKey string - -// requestScopeKey uses a typed string to reduce likelihood of clashing -// with other context keys -const requestScopeKey scopeKey = "request-scope" - -func NewScope(opts *options.Options) alice.Constructor { +func NewScope(reverseProxy bool) alice.Constructor { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { scope := &middlewareapi.RequestScope{ - ReverseProxy: opts.ReverseProxy, + ReverseProxy: reverseProxy, } - contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) - requestWithScope := req.WithContext(contextWithScope) - next.ServeHTTP(rw, requestWithScope) + req = middlewareapi.AddRequestScope(req, scope) + next.ServeHTTP(rw, req) }) } } - -// GetRequestScope returns the current request scope from the given request -func GetRequestScope(req *http.Request) *middlewareapi.RequestScope { - scope := req.Context().Value(requestScopeKey) - if scope == nil { - return nil - } - - return scope.(*middlewareapi.RequestScope) -} diff --git a/pkg/middleware/scope_test.go b/pkg/middleware/scope_test.go index e9533a8d..3432d148 100644 --- a/pkg/middleware/scope_test.go +++ b/pkg/middleware/scope_test.go @@ -1,7 +1,6 @@ package middleware import ( - "context" "net/http" "net/http/httptest" @@ -21,73 +20,49 @@ var _ = Describe("Scope Suite", func() { Expect(err).ToNot(HaveOccurred()) rw = httptest.NewRecorder() - - handler := NewScope()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - nextRequest = r - w.WriteHeader(200) - })) - handler.ServeHTTP(rw, request) }) - It("does not add a scope to the original request", func() { - Expect(request.Context().Value(requestScopeKey)).To(BeNil()) - }) - - It("cannot load a scope from the original request using GetRequestScope", func() { - Expect(GetRequestScope(request)).To(BeNil()) - }) - - It("adds a scope to the request for the next handler", func() { - Expect(nextRequest.Context().Value(requestScopeKey)).ToNot(BeNil()) - }) - - It("can load a scope from the next handler's request using GetRequestScope", func() { - Expect(GetRequestScope(nextRequest)).ToNot(BeNil()) - }) - }) - - Context("GetRequestScope", func() { - var request *http.Request - - BeforeEach(func() { - var err error - request, err = http.NewRequest("", "http://127.0.0.1/", nil) - Expect(err).ToNot(HaveOccurred()) - }) - - Context("with a scope", func() { - var scope *middlewareapi.RequestScope - + Context("ReverseProxy is false", func() { BeforeEach(func() { - scope = &middlewareapi.RequestScope{} - contextWithScope := context.WithValue(request.Context(), requestScopeKey, scope) - request = request.WithContext(contextWithScope) + handler := NewScope(false)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextRequest = r + w.WriteHeader(200) + })) + handler.ServeHTTP(rw, request) }) - It("returns the scope", func() { - s := GetRequestScope(request) - Expect(s).ToNot(BeNil()) - Expect(s).To(Equal(scope)) + It("does not add a scope to the original request", func() { + Expect(request.Context().Value(middlewareapi.RequestScopeKey)).To(BeNil()) }) - Context("if the scope is then modified", func() { - BeforeEach(func() { - Expect(scope.SaveSession).To(BeFalse()) - scope.SaveSession = true - }) + It("cannot load a scope from the original request using GetRequestScope", func() { + Expect(middlewareapi.GetRequestScope(request)).To(BeNil()) + }) - It("returns the updated session", func() { - s := GetRequestScope(request) - Expect(s).ToNot(BeNil()) - Expect(s).To(Equal(scope)) - Expect(s.SaveSession).To(BeTrue()) - }) + It("adds a scope to the request for the next handler", func() { + Expect(nextRequest.Context().Value(middlewareapi.RequestScopeKey)).ToNot(BeNil()) + }) + + It("can load a scope from the next handler's request using GetRequestScope", func() { + scope := middlewareapi.GetRequestScope(nextRequest) + Expect(scope).ToNot(BeNil()) + Expect(scope.ReverseProxy).To(BeFalse()) }) }) - Context("without a scope", func() { - It("returns nil", func() { - Expect(GetRequestScope(request)).To(BeNil()) + Context("ReverseProxy is true", func() { + BeforeEach(func() { + handler := NewScope(true)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextRequest = r + w.WriteHeader(200) + })) + handler.ServeHTTP(rw, request) + }) + + It("return a scope where the ReverseProxy field is true", func() { + scope := middlewareapi.GetRequestScope(nextRequest) + Expect(scope).ToNot(BeNil()) + Expect(scope.ReverseProxy).To(BeTrue()) }) }) }) diff --git a/pkg/middleware/stored_session.go b/pkg/middleware/stored_session.go index 6d86e613..1bd0a9a4 100644 --- a/pkg/middleware/stored_session.go +++ b/pkg/middleware/stored_session.go @@ -8,6 +8,7 @@ import ( "time" "github.com/justinas/alice" + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" ) @@ -59,7 +60,7 @@ type storedSessionLoader struct { // If a session was loader by a previous handler, it will not be replaced. func (s *storedSessionLoader) loadSession(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - scope := GetRequestScope(req) + scope := middlewareapi.GetRequestScope(req) // If scope is nil, this will panic. // A scope should always be injected before this handler is called. if scope.Session != nil { diff --git a/pkg/middleware/stored_session_test.go b/pkg/middleware/stored_session_test.go index 89eadc5d..4a8fd9da 100644 --- a/pkg/middleware/stored_session_test.go +++ b/pkg/middleware/stored_session_test.go @@ -104,8 +104,7 @@ var _ = Describe("Stored Session Suite", func() { // Set up the request with the request headesr and a request scope req := httptest.NewRequest("", "/", nil) req.Header = in.requestHeaders - contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) - req = req.WithContext(contextWithScope) + req = middlewareapi.AddRequestScope(req, scope) rw := httptest.NewRecorder() @@ -120,7 +119,7 @@ var _ = Describe("Stored Session Suite", func() { // from the scope var gotSession *sessionsapi.SessionState handler := NewStoredSessionLoader(opts)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session + gotSession = middlewareapi.GetRequestScope(r).Session })) handler.ServeHTTP(rw, req) diff --git a/pkg/requests/util/util.go b/pkg/requests/util/util.go new file mode 100644 index 00000000..08c9c2c1 --- /dev/null +++ b/pkg/requests/util/util.go @@ -0,0 +1,48 @@ +package util + +import ( + "net/http" + + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" +) + +// GetRequestProto returns the request scheme or X-Forwarded-Proto if present +// and the request is proxied. +func GetRequestProto(req *http.Request) string { + proto := req.Header.Get("X-Forwarded-Proto") + if !IsProxied(req) || proto == "" { + proto = req.URL.Scheme + } + return proto +} + +// GetRequestHost returns the request host header or X-Forwarded-Host if +// present and the request is proxied. +func GetRequestHost(req *http.Request) string { + host := req.Header.Get("X-Forwarded-Host") + if !IsProxied(req) || host == "" { + host = req.Host + } + return host +} + +// GetRequestURI return the request URI or X-Forwarded-Uri if present and the +// request is proxied. +func GetRequestURI(req *http.Request) string { + uri := req.Header.Get("X-Forwarded-Uri") + if !IsProxied(req) || uri == "" { + // Use RequestURI to preserve ?query + uri = req.URL.RequestURI() + } + return uri +} + +// IsProxied determines if a request was from a proxy based on the RequestScope +// ReverseProxy tracker. +func IsProxied(req *http.Request) bool { + scope := middlewareapi.GetRequestScope(req) + if scope == nil { + return false + } + return scope.ReverseProxy +} diff --git a/pkg/requests/util/util_suite_test.go b/pkg/requests/util/util_suite_test.go new file mode 100644 index 00000000..a03f943f --- /dev/null +++ b/pkg/requests/util/util_suite_test.go @@ -0,0 +1,19 @@ +package util_test + +import ( + "testing" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +// TestRequestUtilSuite and related tests are in a *_test package +// to prevent circular imports with the `logger` package which uses +// this functionality +func TestRequestUtilSuite(t *testing.T) { + logger.SetOutput(GinkgoWriter) + + RegisterFailHandler(Fail) + RunSpecs(t, "Request Utils") +} diff --git a/pkg/requests/util/util_test.go b/pkg/requests/util/util_test.go new file mode 100644 index 00000000..595f93f6 --- /dev/null +++ b/pkg/requests/util/util_test.go @@ -0,0 +1,131 @@ +package util_test + +import ( + "fmt" + "net/http" + "net/http/httptest" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Util Suite", func() { + const ( + proto = "http" + host = "www.oauth2proxy.test" + uri = "/test/endpoint" + ) + var req *http.Request + + BeforeEach(func() { + req = httptest.NewRequest( + http.MethodGet, + fmt.Sprintf("%s://%s%s", proto, host, uri), + nil, + ) + }) + + Context("GetRequestHost", func() { + Context("IsProxied is false", func() { + BeforeEach(func() { + req = middleware.AddRequestScope(req, &middleware.RequestScope{}) + }) + + It("returns the host", func() { + Expect(util.GetRequestHost(req)).To(Equal(host)) + }) + + It("ignores X-Forwarded-Host and returns the host", func() { + req.Header.Add("X-Forwarded-Host", "external.oauth2proxy.text") + Expect(util.GetRequestHost(req)).To(Equal(host)) + }) + }) + + Context("IsProxied is true", func() { + BeforeEach(func() { + req = middleware.AddRequestScope(req, &middleware.RequestScope{ + ReverseProxy: true, + }) + }) + + It("returns the host if X-Forwarded-Host is not present", func() { + Expect(util.GetRequestHost(req)).To(Equal(host)) + }) + + It("returns the X-Forwarded-Host when present", func() { + req.Header.Add("X-Forwarded-Host", "external.oauth2proxy.text") + Expect(util.GetRequestHost(req)).To(Equal("external.oauth2proxy.text")) + }) + }) + }) + + Context("GetRequestProto", func() { + Context("IsProxied is false", func() { + BeforeEach(func() { + req = middleware.AddRequestScope(req, &middleware.RequestScope{}) + }) + + It("returns the scheme", func() { + Expect(util.GetRequestProto(req)).To(Equal(proto)) + }) + + It("ignores X-Forwarded-Proto and returns the scheme", func() { + req.Header.Add("X-Forwarded-Proto", "https") + Expect(util.GetRequestProto(req)).To(Equal(proto)) + }) + }) + + Context("IsProxied is true", func() { + BeforeEach(func() { + req = middleware.AddRequestScope(req, &middleware.RequestScope{ + ReverseProxy: true, + }) + }) + + It("returns the scheme if X-Forwarded-Proto is not present", func() { + Expect(util.GetRequestProto(req)).To(Equal(proto)) + }) + + It("returns the X-Forwarded-Proto when present", func() { + req.Header.Add("X-Forwarded-Proto", "https") + Expect(util.GetRequestProto(req)).To(Equal("https")) + }) + }) + }) + + Context("GetRequestURI", func() { + Context("IsProxied is false", func() { + BeforeEach(func() { + req = middleware.AddRequestScope(req, &middleware.RequestScope{}) + }) + + It("returns the URI", func() { + Expect(util.GetRequestURI(req)).To(Equal(uri)) + }) + + It("ignores X-Forwarded-Uri and returns the URI", func() { + req.Header.Add("X-Forwarded-Uri", "/some/other/path") + Expect(util.GetRequestURI(req)).To(Equal(uri)) + }) + }) + + Context("IsProxied is true", func() { + BeforeEach(func() { + req = middleware.AddRequestScope(req, &middleware.RequestScope{ + ReverseProxy: true, + }) + }) + + It("returns the URI if X-Forwarded-Uri is not present", func() { + Expect(util.GetRequestURI(req)).To(Equal(uri)) + }) + + It("returns the X-Forwarded-Uri when present", func() { + req.Header.Add("X-Forwarded-Uri", "/some/other/path") + Expect(util.GetRequestURI(req)).To(Equal("/some/other/path")) + }) + }) + }) +}) diff --git a/pkg/util/util.go b/pkg/util/util.go index 452e14f1..4519fdb8 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -4,9 +4,6 @@ import ( "crypto/x509" "fmt" "io/ioutil" - "net/http" - - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" ) func GetCertPool(paths []string) (*x509.CertPool, error) { @@ -26,37 +23,3 @@ func GetCertPool(paths []string) (*x509.CertPool, error) { } return pool, nil } - -// GetRequestProto return the request host header or X-Forwarded-Proto if present -func GetRequestProto(req *http.Request) string { - proto := req.Header.Get("X-Forwarded-Proto") - if !isProxied(req) || proto == "" { - proto = req.URL.Scheme - } - return proto -} - -// GetRequestHost return the request host header or X-Forwarded-Host if present -// and reverse proxy mode is enabled. -func GetRequestHost(req *http.Request) string { - host := req.Header.Get("X-Forwarded-Host") - if !isProxied(req) || host == "" { - host = req.Host - } - return host -} - -// GetRequestURI return the request host header or X-Forwarded-Uri if present -func GetRequestURI(req *http.Request) string { - uri := req.Header.Get("X-Forwarded-Uri") - if !isProxied(req) || uri == "" { - // Use RequestURI to preserve ?query - uri = req.URL.RequestURI() - } - return uri -} - -func isProxied(req *http.Request) bool { - scope := middleware.GetRequestScope(req) - return scope.ReverseProxy -} diff --git a/pkg/util/util_test.go b/pkg/util/util_test.go index d032025e..347f41bb 100644 --- a/pkg/util/util_test.go +++ b/pkg/util/util_test.go @@ -4,11 +4,9 @@ import ( "crypto/x509/pkix" "encoding/asn1" "io/ioutil" - "net/http/httptest" "os" "testing" - . "github.com/onsi/gomega" "github.com/stretchr/testify/assert" ) @@ -97,42 +95,3 @@ func TestGetCertPool(t *testing.T) { expectedSubjects := []string{testCA1Subj, testCA2Subj} assert.Equal(t, expectedSubjects, got) } - -func TestGetRequestHost(t *testing.T) { - g := NewWithT(t) - - req := httptest.NewRequest("GET", "https://example.com", nil) - host := GetRequestHost(req) - g.Expect(host).To(Equal("example.com")) - - proxyReq := httptest.NewRequest("GET", "http://internal.example.com", nil) - proxyReq.Header.Add("X-Forwarded-Host", "external.example.com") - extHost := GetRequestHost(proxyReq) - g.Expect(extHost).To(Equal("external.example.com")) -} - -func TestGetRequestProto(t *testing.T) { - g := NewWithT(t) - - req := httptest.NewRequest("GET", "https://example.com", nil) - proto := GetRequestProto(req) - g.Expect(proto).To(Equal("https")) - - proxyReq := httptest.NewRequest("GET", "https://internal.example.com", nil) - proxyReq.Header.Add("X-Forwarded-Proto", "http") - extProto := GetRequestProto(proxyReq) - g.Expect(extProto).To(Equal("http")) -} - -func TestGetRequestURI(t *testing.T) { - g := NewWithT(t) - - req := httptest.NewRequest("GET", "https://example.com/ping", nil) - uri := GetRequestURI(req) - g.Expect(uri).To(Equal("/ping")) - - proxyReq := httptest.NewRequest("GET", "http://internal.example.com/bong", nil) - proxyReq.Header.Add("X-Forwarded-Uri", "/ping") - extURI := GetRequestURI(proxyReq) - g.Expect(extURI).To(Equal("/ping")) -} From f054682fb7a879d7fc285d4fdeeb243e6d92f9f0 Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Sat, 2 Jan 2021 13:16:45 -0800 Subject: [PATCH 25/28] Make HTTPS Redirect middleware Reverse Proxy aware --- pkg/middleware/redirect_to_https.go | 13 +++++----- pkg/middleware/redirect_to_https_test.go | 31 ++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/pkg/middleware/redirect_to_https.go b/pkg/middleware/redirect_to_https.go index 18b4b967..72f9dac4 100644 --- a/pkg/middleware/redirect_to_https.go +++ b/pkg/middleware/redirect_to_https.go @@ -7,7 +7,7 @@ import ( "strings" "github.com/justinas/alice" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util" + requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util" ) const httpsScheme = "https" @@ -26,10 +26,11 @@ func NewRedirectToHTTPS(httpsPort string) alice.Constructor { // to the port from the httpsAddress given. func redirectToHTTPS(httpsPort string, next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - proto := req.Header.Get("X-Forwarded-Proto") - if strings.EqualFold(proto, httpsScheme) || (req.TLS != nil && proto == "") { - // Only care about the connection to us being HTTPS if the proto is empty, - // otherwise the proto is source of truth + proto := requestutil.GetRequestProto(req) + if strings.EqualFold(proto, httpsScheme) || (req.TLS != nil && proto == req.URL.Scheme) { + // Only care about the connection to us being HTTPS if the proto wasn't + // from a trusted `X-Forwarded-Proto` (proto == req.URL.Scheme). + // Otherwise the proto is source of truth next.ServeHTTP(rw, req) return } @@ -41,7 +42,7 @@ func redirectToHTTPS(httpsPort string, next http.Handler) http.Handler { // Set the Host in case the targetURL still does not have one // or it isn't X-Forwarded-Host aware - targetURL.Host = util.GetRequestHost(req) + targetURL.Host = requestutil.GetRequestHost(req) // Overwrite the port if the original request was to a non-standard port if targetURL.Port() != "" { diff --git a/pkg/middleware/redirect_to_https_test.go b/pkg/middleware/redirect_to_https_test.go index ca8bdb99..f8c2c6bb 100644 --- a/pkg/middleware/redirect_to_https_test.go +++ b/pkg/middleware/redirect_to_https_test.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http/httptest" + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" . "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo/extensions/table" . "github.com/onsi/gomega" @@ -21,6 +22,7 @@ var _ = Describe("RedirectToHTTPS suite", func() { requestString string useTLS bool headers map[string]string + reverseProxy bool expectedStatus int expectedBody string expectedLocation string @@ -35,6 +37,10 @@ var _ = Describe("RedirectToHTTPS suite", func() { if in.useTLS { req.TLS = &tls.ConnectionState{} } + scope := &middlewareapi.RequestScope{ + ReverseProxy: in.reverseProxy, + } + req = middlewareapi.AddRequestScope(req, scope) rw := httptest.NewRecorder() @@ -52,6 +58,7 @@ var _ = Describe("RedirectToHTTPS suite", func() { requestString: "http://example.com", useTLS: false, headers: map[string]string{}, + reverseProxy: false, expectedStatus: 308, expectedBody: permanentRedirectBody("https://example.com"), expectedLocation: "https://example.com", @@ -60,6 +67,7 @@ var _ = Describe("RedirectToHTTPS suite", func() { requestString: "https://example.com", useTLS: true, headers: map[string]string{}, + reverseProxy: false, expectedStatus: 200, expectedBody: "test", }), @@ -69,15 +77,28 @@ var _ = Describe("RedirectToHTTPS suite", func() { headers: map[string]string{ "X-Forwarded-Proto": "HTTPS", }, + reverseProxy: true, expectedStatus: 200, expectedBody: "test", }), + Entry("without TLS and X-Forwarded-Proto=HTTPS but ReverseProxy not set", &requestTableInput{ + requestString: "http://example.com", + useTLS: false, + headers: map[string]string{ + "X-Forwarded-Proto": "HTTPS", + }, + reverseProxy: false, + expectedStatus: 308, + expectedBody: permanentRedirectBody("https://example.com"), + expectedLocation: "https://example.com", + }), Entry("with TLS and X-Forwarded-Proto=HTTPS", &requestTableInput{ requestString: "https://example.com", useTLS: true, headers: map[string]string{ "X-Forwarded-Proto": "HTTPS", }, + reverseProxy: true, expectedStatus: 200, expectedBody: "test", }), @@ -87,6 +108,7 @@ var _ = Describe("RedirectToHTTPS suite", func() { headers: map[string]string{ "X-Forwarded-Proto": "https", }, + reverseProxy: true, expectedStatus: 200, expectedBody: "test", }), @@ -96,6 +118,7 @@ var _ = Describe("RedirectToHTTPS suite", func() { headers: map[string]string{ "X-Forwarded-Proto": "https", }, + reverseProxy: true, expectedStatus: 200, expectedBody: "test", }), @@ -105,6 +128,7 @@ var _ = Describe("RedirectToHTTPS suite", func() { headers: map[string]string{ "X-Forwarded-Proto": "HTTP", }, + reverseProxy: true, expectedStatus: 308, expectedBody: permanentRedirectBody("https://example.com"), expectedLocation: "https://example.com", @@ -115,6 +139,7 @@ var _ = Describe("RedirectToHTTPS suite", func() { headers: map[string]string{ "X-Forwarded-Proto": "HTTP", }, + reverseProxy: true, expectedStatus: 308, expectedBody: permanentRedirectBody("https://example.com"), expectedLocation: "https://example.com", @@ -125,6 +150,7 @@ var _ = Describe("RedirectToHTTPS suite", func() { headers: map[string]string{ "X-Forwarded-Proto": "http", }, + reverseProxy: true, expectedStatus: 308, expectedBody: permanentRedirectBody("https://example.com"), expectedLocation: "https://example.com", @@ -135,6 +161,7 @@ var _ = Describe("RedirectToHTTPS suite", func() { headers: map[string]string{ "X-Forwarded-Proto": "http", }, + reverseProxy: true, expectedStatus: 308, expectedBody: permanentRedirectBody("https://example.com"), expectedLocation: "https://example.com", @@ -143,6 +170,7 @@ var _ = Describe("RedirectToHTTPS suite", func() { requestString: "http://example.com:8080", useTLS: false, headers: map[string]string{}, + reverseProxy: false, expectedStatus: 308, expectedBody: permanentRedirectBody("https://example.com:8443"), expectedLocation: "https://example.com:8443", @@ -151,6 +179,7 @@ var _ = Describe("RedirectToHTTPS suite", func() { requestString: "https://example.com:8443", useTLS: true, headers: map[string]string{}, + reverseProxy: false, expectedStatus: 200, expectedBody: "test", }), @@ -161,6 +190,7 @@ var _ = Describe("RedirectToHTTPS suite", func() { requestString: "/", useTLS: false, expectedStatus: 308, + reverseProxy: false, expectedBody: permanentRedirectBody("https://example.com/"), expectedLocation: "https://example.com/", }), @@ -171,6 +201,7 @@ var _ = Describe("RedirectToHTTPS suite", func() { "X-Forwarded-Proto": "HTTP", "X-Forwarded-Host": "external.example.com", }, + reverseProxy: true, expectedStatus: 308, expectedBody: permanentRedirectBody("https://external.example.com"), expectedLocation: "https://external.example.com", From 73fc7706bca4c4be241006c7f0ac2bbcaa160ec0 Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Sat, 2 Jan 2021 13:46:05 -0800 Subject: [PATCH 26/28] Figure out final app redirect URL with proxy aware request utils --- oauthproxy.go | 150 ++++++++++++++++++++++++++++++--------------- oauthproxy_test.go | 20 +++--- 2 files changed, 111 insertions(+), 59 deletions(-) diff --git a/oauthproxy.go b/oauthproxy.go index b51b9bea..a595bc3b 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -24,9 +24,9 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" + requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/upstream" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util" "github.com/oauth2-proxy/oauth2-proxy/v7/providers" ) @@ -98,7 +98,6 @@ type OAuthProxy struct { SetAuthorization bool PassAuthorization bool PreferEmailToUser bool - ReverseProxy bool skipAuthPreflight bool skipJwtBearerTokens bool templates *template.Template @@ -201,7 +200,6 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr UserInfoPath: fmt.Sprintf("%s/userinfo", opts.ProxyPrefix), ProxyPrefix: opts.ProxyPrefix, - ReverseProxy: opts.ReverseProxy, provider: opts.GetProvider(), providerNameOverride: opts.ProviderName, sessionStore: sessionStore, @@ -231,7 +229,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr // the OAuth2 Proxy authentication logic kicks in. // For example forcing HTTPS or health checks. func buildPreAuthChain(opts *options.Options) (alice.Chain, error) { - chain := alice.New(middleware.NewScope(opts)) + chain := alice.New(middleware.NewScope(opts.ReverseProxy)) if opts.ForceHTTPS { _, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress) @@ -368,9 +366,9 @@ func buildRoutesAllowlist(opts *options.Options) ([]allowedRoute, error) { return routes, nil } -// GetRedirectURI returns the redirectURL that the upstream OAuth Provider will +// GetOAuthRedirectURI returns the redirectURL that the upstream OAuth Provider will // redirect clients to once authenticated -func (p *OAuthProxy) GetRedirectURI(host string) string { +func (p *OAuthProxy) GetOAuthRedirectURI(host string) string { // default to the request Host if not set if p.redirectURL.Host != "" { return p.redirectURL.String() @@ -391,7 +389,7 @@ func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (*sessio if code == "" { return nil, providers.ErrMissingCode } - redirectURI := p.GetRedirectURI(host) + redirectURI := p.GetOAuthRedirectURI(host) s, err := p.provider.Redeem(ctx, redirectURI, code) if err != nil { return nil, err @@ -420,7 +418,7 @@ func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, ex cookieDomain := cookies.GetCookieDomain(req, p.CookieDomains) if cookieDomain != "" { - domain := util.GetRequestHost(req) + domain := requestutil.GetRequestHost(req) if h, _, err := net.SplitHostPort(domain); err == nil { domain = h } @@ -509,7 +507,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code } rw.WriteHeader(code) - redirectURL, err := p.GetRedirect(req) + redirectURL, err := p.GetAppRedirect(req) if err != nil { logger.Errorf("Error obtaining redirect: %v", err) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) @@ -568,46 +566,108 @@ func (p *OAuthProxy) ManualSignIn(req *http.Request) (string, bool) { return "", false } -// GetRedirect reads the query parameter to get the URL to redirect clients to +// GetAppRedirect determines the full URL or URI path to redirect clients to // once authenticated with the OAuthProxy -func (p *OAuthProxy) GetRedirect(req *http.Request) (redirect string, err error) { - err = req.ParseForm() +// Strategy priority (first legal result is used): +// - `rd` querysting parameter +// - `X-Auth-Request-Redirect` header +// - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled) +// - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*) +// - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled) +// - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*) +// - `/` +func (p *OAuthProxy) GetAppRedirect(req *http.Request) (string, error) { + err := req.ParseForm() if err != nil { - return + return "", err } - redirect = req.Header.Get("X-Auth-Request-Redirect") - if req.Form.Get("rd") != "" { - redirect = req.Form.Get("rd") - } - // Quirk: On reverse proxies that doesn't have support for - // "X-Auth-Request-Redirect" header or dynamic header/query string - // manipulation (like Traefik v1 and v2), we can try if the header - // X-Forwarded-Host exists or not. - if redirect == "" && isForwardedRequest(req, p.ReverseProxy) { - redirect = p.getRedirectFromForwardHeaders(req) - } - if !p.IsValidRedirect(redirect) { - // Use RequestURI to preserve ?query - redirect = req.URL.RequestURI() - - if strings.HasPrefix(redirect, fmt.Sprintf("%s/", p.ProxyPrefix)) { - redirect = "/" + // These redirect getter functions are strategies ordered by priority + // for figuring out the redirect URL. + type redirectGetter func(req *http.Request) string + for _, rdGetter := range []redirectGetter{ + p.getRdQuerystringRedirect, + p.getXAuthRequestRedirect, + p.getXForwardedHeadersRedirect, + p.getURIRedirect, + } { + if redirect := rdGetter(req); redirect != "" { + return redirect, nil } } - return + return "/", nil } -// getRedirectFromForwardHeaders returns the redirect URL based on X-Forwarded-{Proto,Host,Uri} headers -func (p *OAuthProxy) getRedirectFromForwardHeaders(req *http.Request) string { - uri := util.GetRequestURI(req) +func isForwardedRequest(req *http.Request) bool { + return requestutil.IsProxied(req) && + req.Host != requestutil.GetRequestHost(req) +} - if strings.HasPrefix(uri, fmt.Sprintf("%s/", p.ProxyPrefix)) { +func (p *OAuthProxy) hasProxyPrefix(path string) bool { + return strings.HasPrefix(path, fmt.Sprintf("%s/", p.ProxyPrefix)) +} + +// getRdQuerystringRedirect handles this GetAppRedirect strategy: +// - `rd` querysting parameter +func (p *OAuthProxy) getRdQuerystringRedirect(req *http.Request) string { + redirect := req.Form.Get("rd") + if p.IsValidRedirect(redirect) { + return redirect + } + return "" +} + +// getXAuthRequestRedirect handles this GetAppRedirect strategy: +// - `X-Auth-Request-Redirect` Header +func (p *OAuthProxy) getXAuthRequestRedirect(req *http.Request) string { + redirect := req.Header.Get("X-Auth-Request-Redirect") + if p.IsValidRedirect(redirect) { + return redirect + } + return "" +} + +// getXForwardedHeadersRedirect handles these GetAppRedirect strategies: +// - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled) +// - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*) +func (p *OAuthProxy) getXForwardedHeadersRedirect(req *http.Request) string { + if !isForwardedRequest(req) { + return "" + } + + uri := requestutil.GetRequestURI(req) + if p.hasProxyPrefix(uri) { uri = "/" } - return fmt.Sprintf("%s://%s%s", util.GetRequestProto(req), util.GetRequestHost(req), uri) + redirect := fmt.Sprintf( + "%s://%s%s", + requestutil.GetRequestProto(req), + requestutil.GetRequestHost(req), + uri, + ) + + if p.IsValidRedirect(redirect) { + return redirect + } + return "" +} + +// getURIRedirect handles these GetAppRedirect strategies: +// - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled) +// - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*) +// - `/` +func (p *OAuthProxy) getURIRedirect(req *http.Request) string { + redirect := requestutil.GetRequestURI(req) + if !p.IsValidRedirect(redirect) { + redirect = req.URL.RequestURI() + } + + if p.hasProxyPrefix(redirect) { + return "/" + } + return redirect } // splitHostPort separates host and port. If the port is not valid, it returns @@ -707,12 +767,6 @@ func (p *OAuthProxy) isAllowedRoute(req *http.Request) bool { return false } -// isForwardedRequest is used to check if X-Forwarded-Host header exists or not -func isForwardedRequest(req *http.Request, reverseProxy bool) bool { - isForwarded := req.Host != util.GetRequestHost(req) - return isForwarded && reverseProxy -} - // See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en var noCacheHeaders = map[string]string{ "Expires": time.Unix(0, 0).Format(time.RFC1123), @@ -781,7 +835,7 @@ func (p *OAuthProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) { // SignIn serves a page prompting users to sign in func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { - redirect, err := p.GetRedirect(req) + redirect, err := p.GetAppRedirect(req) if err != nil { logger.Errorf("Error obtaining redirect: %v", err) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) @@ -839,7 +893,7 @@ func (p *OAuthProxy) UserInfo(rw http.ResponseWriter, req *http.Request) { // SignOut sends a response to clear the authentication cookie func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) { - redirect, err := p.GetRedirect(req) + redirect, err := p.GetAppRedirect(req) if err != nil { logger.Errorf("Error obtaining redirect: %v", err) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) @@ -864,13 +918,13 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) { return } p.SetCSRFCookie(rw, req, nonce) - redirect, err := p.GetRedirect(req) + redirect, err := p.GetAppRedirect(req) if err != nil { logger.Errorf("Error obtaining redirect: %v", err) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) return } - redirectURI := p.GetRedirectURI(util.GetRequestHost(req)) + redirectURI := p.GetOAuthRedirectURI(requestutil.GetRequestHost(req)) http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), http.StatusFound) } @@ -893,7 +947,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { return } - session, err := p.redeemCode(req.Context(), util.GetRequestHost(req), req.Form.Get("code")) + session, err := p.redeemCode(req.Context(), requestutil.GetRequestHost(req), req.Form.Get("code")) if err != nil { logger.Errorf("Error redeeming code during OAuth2 callback: %v", err) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error") @@ -1024,7 +1078,7 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R var session *sessionsapi.SessionState getSession := p.sessionChain.Then(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - session = middleware.GetRequestScope(req).Session + session = middlewareapi.GetRequestScope(req).Session })) getSession.ServeHTTP(rw, req) diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 52ffa2b3..8adea1ce 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -19,6 +19,7 @@ import ( "github.com/coreos/go-oidc" "github.com/mbland/hmacauth" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" @@ -1750,8 +1751,7 @@ func TestRequestSignature(t *testing.T) { func TestGetRedirect(t *testing.T) { opts := baseTestOptions() - opts.WhitelistDomains = append(opts.WhitelistDomains, ".example.com") - opts.WhitelistDomains = append(opts.WhitelistDomains, ".example.com:8443") + opts.WhitelistDomains = append(opts.WhitelistDomains, ".example.com", ".example.com:8443") err := validation.Validate(opts) assert.NoError(t, err) require.NotEmpty(t, opts.ProxyPrefix) @@ -1854,9 +1854,6 @@ func TestGetRedirect(t *testing.T) { url: "https://oauth.example.com/foo/bar", headers: map[string]string{ "X-Auth-Request-Redirect": "https://a-service.example.com/foo/bar", - "X-Forwarded-Proto": "", - "X-Forwarded-Host": "", - "X-Forwarded-Uri": "", }, reverseProxy: true, expectedRedirect: "https://a-service.example.com/foo/bar", @@ -1884,10 +1881,9 @@ func TestGetRedirect(t *testing.T) { name: "proxied request with rd query string and some headers set redirects to proxied URL on rd query string", url: "https://oauth.example.com/foo/bar?rd=https%3A%2F%2Fa%2Dservice%2Eexample%2Ecom%2Ffoo%2Fbaz", headers: map[string]string{ - "X-Auth-Request-Redirect": "", - "X-Forwarded-Proto": "https", - "X-Forwarded-Host": "another-service.example.com", - "X-Forwarded-Uri": "/seasons/greetings", + "X-Forwarded-Proto": "https", + "X-Forwarded-Host": "another-service.example.com", + "X-Forwarded-Uri": "/seasons/greetings", }, reverseProxy: true, expectedRedirect: "https://a-service.example.com/foo/baz", @@ -1901,8 +1897,10 @@ func TestGetRedirect(t *testing.T) { req.Header.Add(header, value) } } - proxy.ReverseProxy = tt.reverseProxy - redirect, err := proxy.GetRedirect(req) + req = middleware.AddRequestScope(req, &middleware.RequestScope{ + ReverseProxy: tt.reverseProxy, + }) + redirect, err := proxy.GetAppRedirect(req) assert.NoError(t, err) assert.Equal(t, tt.expectedRedirect, redirect) From fa6a785eafa29ecb32764c52ca75807f501cf090 Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Sat, 2 Jan 2021 14:20:48 -0800 Subject: [PATCH 27/28] Improve handler vs helper organization in oauthproxy.go Additionally, convert a lot of helper methods to be private --- oauthproxy.go | 652 +++++++++++++++++++++++---------------------- oauthproxy_test.go | 9 +- 2 files changed, 332 insertions(+), 329 deletions(-) diff --git a/oauthproxy.go b/oauthproxy.go index a595bc3b..28f667b3 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -31,9 +31,7 @@ import ( ) const ( - httpScheme = "http" - httpsScheme = "https" - + schemeHTTPS = "https" applicationJSON = "application/json" ) @@ -366,49 +364,6 @@ func buildRoutesAllowlist(opts *options.Options) ([]allowedRoute, error) { return routes, nil } -// GetOAuthRedirectURI returns the redirectURL that the upstream OAuth Provider will -// redirect clients to once authenticated -func (p *OAuthProxy) GetOAuthRedirectURI(host string) string { - // default to the request Host if not set - if p.redirectURL.Host != "" { - return p.redirectURL.String() - } - u := *p.redirectURL - if u.Scheme == "" { - if p.CookieSecure { - u.Scheme = httpsScheme - } else { - u.Scheme = httpScheme - } - } - u.Host = host - return u.String() -} - -func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (*sessionsapi.SessionState, error) { - if code == "" { - return nil, providers.ErrMissingCode - } - redirectURI := p.GetOAuthRedirectURI(host) - s, err := p.provider.Redeem(ctx, redirectURI, code) - if err != nil { - return nil, err - } - return s, nil -} - -func (p *OAuthProxy) enrichSessionState(ctx context.Context, s *sessionsapi.SessionState) error { - var err error - if s.Email == "" { - s.Email, err = p.provider.GetEmailAddress(ctx, s) - if err != nil && !errors.Is(err, providers.ErrNotImplemented) { - return err - } - } - - return p.provider.EnrichSession(ctx, s) -} - // MakeCSRFCookie creates a cookie for CSRF func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { return p.makeCookie(req, p.CSRFCookieName, value, expiration, now) @@ -466,6 +421,81 @@ func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *s return p.sessionStore.Save(rw, req, s) } +// IsValidRedirect checks whether the redirect URL is whitelisted +func (p *OAuthProxy) IsValidRedirect(redirect string) bool { + switch { + case redirect == "": + // The user didn't specify a redirect, should fallback to `/` + return false + case strings.HasPrefix(redirect, "/") && !strings.HasPrefix(redirect, "//") && !invalidRedirectRegex.MatchString(redirect): + return true + case strings.HasPrefix(redirect, "http://") || strings.HasPrefix(redirect, "https://"): + redirectURL, err := url.Parse(redirect) + if err != nil { + logger.Printf("Rejecting invalid redirect %q: scheme unsupported or missing", redirect) + return false + } + redirectHostname := redirectURL.Hostname() + + for _, domain := range p.whitelistDomains { + domainHostname, domainPort := splitHostPort(strings.TrimLeft(domain, ".")) + if domainHostname == "" { + continue + } + + if (redirectHostname == domainHostname) || (strings.HasPrefix(domain, ".") && strings.HasSuffix(redirectHostname, domainHostname)) { + // the domain names match, now validate the ports + // if the whitelisted domain's port is '*', allow all ports + // if the whitelisted domain contains a specific port, only allow that port + // if the whitelisted domain doesn't contain a port at all, only allow empty redirect ports ie http and https + redirectPort := redirectURL.Port() + if (domainPort == "*") || + (domainPort == redirectPort) || + (domainPort == "" && redirectPort == "") { + return true + } + } + } + + logger.Printf("Rejecting invalid redirect %q: domain / port not in whitelist", redirect) + return false + default: + logger.Printf("Rejecting invalid redirect %q: not an absolute or relative URL", redirect) + return false + } +} + +func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + p.preAuthChain.Then(http.HandlerFunc(p.serveHTTP)).ServeHTTP(rw, req) +} + +func (p *OAuthProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) { + if req.URL.Path != p.AuthOnlyPath && strings.HasPrefix(req.URL.Path, p.ProxyPrefix) { + prepareNoCache(rw) + } + + switch path := req.URL.Path; { + case path == p.RobotsPath: + p.RobotsTxt(rw) + case p.IsAllowedRequest(req): + p.SkipAuthProxy(rw, req) + case path == p.SignInPath: + p.SignIn(rw, req) + case path == p.SignOutPath: + p.SignOut(rw, req) + case path == p.OAuthStartPath: + p.OAuthStart(rw, req) + case path == p.OAuthCallbackPath: + p.OAuthCallback(rw, req) + case path == p.AuthOnlyPath: + p.AuthOnly(rw, req) + case path == p.UserInfoPath: + p.UserInfo(rw, req) + default: + p.Proxy(rw, req) + } +} + // RobotsTxt disallows scraping pages from the OAuthProxy func (p *OAuthProxy) RobotsTxt(rw http.ResponseWriter) { _, err := fmt.Fprintf(rw, "User-agent: *\nDisallow: /") @@ -496,6 +526,42 @@ func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, m } } +// IsAllowedRequest is used to check if auth should be skipped for this request +func (p *OAuthProxy) IsAllowedRequest(req *http.Request) bool { + isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS" + return isPreflightRequestAllowed || p.isAllowedRoute(req) || p.isTrustedIP(req) +} + +// IsAllowedRoute is used to check if the request method & path is allowed without auth +func (p *OAuthProxy) isAllowedRoute(req *http.Request) bool { + for _, route := range p.allowedRoutes { + if (route.method == "" || req.Method == route.method) && route.pathRegex.MatchString(req.URL.Path) { + return true + } + } + return false +} + +// isTrustedIP is used to check if a request comes from a trusted client IP address. +func (p *OAuthProxy) isTrustedIP(req *http.Request) bool { + if p.trustedIPs == nil { + return false + } + + remoteAddr, err := ip.GetClientIP(p.realClientIPParser, req) + if err != nil { + logger.Errorf("Error obtaining real IP for trusted IP list: %v", err) + // Possibly spoofed X-Real-IP header + return false + } + + if remoteAddr == nil { + return false + } + + return p.trustedIPs.Has(remoteAddr) +} + // SignInPage writes the sing in template to the response func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code int) { prepareNoCache(rw) @@ -507,7 +573,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code } rw.WriteHeader(code) - redirectURL, err := p.GetAppRedirect(req) + redirectURL, err := p.getAppRedirect(req) if err != nil { logger.Errorf("Error obtaining redirect: %v", err) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) @@ -566,276 +632,9 @@ func (p *OAuthProxy) ManualSignIn(req *http.Request) (string, bool) { return "", false } -// GetAppRedirect determines the full URL or URI path to redirect clients to -// once authenticated with the OAuthProxy -// Strategy priority (first legal result is used): -// - `rd` querysting parameter -// - `X-Auth-Request-Redirect` header -// - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled) -// - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*) -// - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled) -// - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*) -// - `/` -func (p *OAuthProxy) GetAppRedirect(req *http.Request) (string, error) { - err := req.ParseForm() - if err != nil { - return "", err - } - - // These redirect getter functions are strategies ordered by priority - // for figuring out the redirect URL. - type redirectGetter func(req *http.Request) string - for _, rdGetter := range []redirectGetter{ - p.getRdQuerystringRedirect, - p.getXAuthRequestRedirect, - p.getXForwardedHeadersRedirect, - p.getURIRedirect, - } { - if redirect := rdGetter(req); redirect != "" { - return redirect, nil - } - } - - return "/", nil -} - -func isForwardedRequest(req *http.Request) bool { - return requestutil.IsProxied(req) && - req.Host != requestutil.GetRequestHost(req) -} - -func (p *OAuthProxy) hasProxyPrefix(path string) bool { - return strings.HasPrefix(path, fmt.Sprintf("%s/", p.ProxyPrefix)) -} - -// getRdQuerystringRedirect handles this GetAppRedirect strategy: -// - `rd` querysting parameter -func (p *OAuthProxy) getRdQuerystringRedirect(req *http.Request) string { - redirect := req.Form.Get("rd") - if p.IsValidRedirect(redirect) { - return redirect - } - return "" -} - -// getXAuthRequestRedirect handles this GetAppRedirect strategy: -// - `X-Auth-Request-Redirect` Header -func (p *OAuthProxy) getXAuthRequestRedirect(req *http.Request) string { - redirect := req.Header.Get("X-Auth-Request-Redirect") - if p.IsValidRedirect(redirect) { - return redirect - } - return "" -} - -// getXForwardedHeadersRedirect handles these GetAppRedirect strategies: -// - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled) -// - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*) -func (p *OAuthProxy) getXForwardedHeadersRedirect(req *http.Request) string { - if !isForwardedRequest(req) { - return "" - } - - uri := requestutil.GetRequestURI(req) - if p.hasProxyPrefix(uri) { - uri = "/" - } - - redirect := fmt.Sprintf( - "%s://%s%s", - requestutil.GetRequestProto(req), - requestutil.GetRequestHost(req), - uri, - ) - - if p.IsValidRedirect(redirect) { - return redirect - } - return "" -} - -// getURIRedirect handles these GetAppRedirect strategies: -// - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled) -// - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*) -// - `/` -func (p *OAuthProxy) getURIRedirect(req *http.Request) string { - redirect := requestutil.GetRequestURI(req) - if !p.IsValidRedirect(redirect) { - redirect = req.URL.RequestURI() - } - - if p.hasProxyPrefix(redirect) { - return "/" - } - return redirect -} - -// splitHostPort separates host and port. If the port is not valid, it returns -// the entire input as host, and it doesn't check the validity of the host. -// Unlike net.SplitHostPort, but per RFC 3986, it requires ports to be numeric. -// *** taken from net/url, modified validOptionalPort() to accept ":*" -func splitHostPort(hostport string) (host, port string) { - host = hostport - - colon := strings.LastIndexByte(host, ':') - if colon != -1 && validOptionalPort(host[colon:]) { - host, port = host[:colon], host[colon+1:] - } - - if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { - host = host[1 : len(host)-1] - } - - return -} - -// validOptionalPort reports whether port is either an empty string -// or matches /^:\d*$/ -// *** taken from net/url, modified to accept ":*" -func validOptionalPort(port string) bool { - if port == "" || port == ":*" { - return true - } - if port[0] != ':' { - return false - } - for _, b := range port[1:] { - if b < '0' || b > '9' { - return false - } - } - return true -} - -// IsValidRedirect checks whether the redirect URL is whitelisted -func (p *OAuthProxy) IsValidRedirect(redirect string) bool { - switch { - case redirect == "": - // The user didn't specify a redirect, should fallback to `/` - return false - case strings.HasPrefix(redirect, "/") && !strings.HasPrefix(redirect, "//") && !invalidRedirectRegex.MatchString(redirect): - return true - case strings.HasPrefix(redirect, "http://") || strings.HasPrefix(redirect, "https://"): - redirectURL, err := url.Parse(redirect) - if err != nil { - logger.Printf("Rejecting invalid redirect %q: scheme unsupported or missing", redirect) - return false - } - redirectHostname := redirectURL.Hostname() - - for _, domain := range p.whitelistDomains { - domainHostname, domainPort := splitHostPort(strings.TrimLeft(domain, ".")) - if domainHostname == "" { - continue - } - - if (redirectHostname == domainHostname) || (strings.HasPrefix(domain, ".") && strings.HasSuffix(redirectHostname, domainHostname)) { - // the domain names match, now validate the ports - // if the whitelisted domain's port is '*', allow all ports - // if the whitelisted domain contains a specific port, only allow that port - // if the whitelisted domain doesn't contain a port at all, only allow empty redirect ports ie http and https - redirectPort := redirectURL.Port() - if (domainPort == "*") || - (domainPort == redirectPort) || - (domainPort == "" && redirectPort == "") { - return true - } - } - } - - logger.Printf("Rejecting invalid redirect %q: domain / port not in whitelist", redirect) - return false - default: - logger.Printf("Rejecting invalid redirect %q: not an absolute or relative URL", redirect) - return false - } -} - -// IsAllowedRequest is used to check if auth should be skipped for this request -func (p *OAuthProxy) IsAllowedRequest(req *http.Request) bool { - isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS" - return isPreflightRequestAllowed || p.isAllowedRoute(req) || p.IsTrustedIP(req) -} - -// IsAllowedRoute is used to check if the request method & path is allowed without auth -func (p *OAuthProxy) isAllowedRoute(req *http.Request) bool { - for _, route := range p.allowedRoutes { - if (route.method == "" || req.Method == route.method) && route.pathRegex.MatchString(req.URL.Path) { - return true - } - } - return false -} - -// See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en -var noCacheHeaders = map[string]string{ - "Expires": time.Unix(0, 0).Format(time.RFC1123), - "Cache-Control": "no-cache, no-store, must-revalidate, max-age=0", - "X-Accel-Expires": "0", // https://www.nginx.com/resources/wiki/start/topics/examples/x-accel/ -} - -// prepareNoCache prepares headers for preventing browser caching. -func prepareNoCache(w http.ResponseWriter) { - // Set NoCache headers - for k, v := range noCacheHeaders { - w.Header().Set(k, v) - } -} - -// IsTrustedIP is used to check if a request comes from a trusted client IP address. -func (p *OAuthProxy) IsTrustedIP(req *http.Request) bool { - if p.trustedIPs == nil { - return false - } - - remoteAddr, err := ip.GetClientIP(p.realClientIPParser, req) - if err != nil { - logger.Errorf("Error obtaining real IP for trusted IP list: %v", err) - // Possibly spoofed X-Real-IP header - return false - } - - if remoteAddr == nil { - return false - } - - return p.trustedIPs.Has(remoteAddr) -} - -func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - p.preAuthChain.Then(http.HandlerFunc(p.serveHTTP)).ServeHTTP(rw, req) -} - -func (p *OAuthProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) { - if req.URL.Path != p.AuthOnlyPath && strings.HasPrefix(req.URL.Path, p.ProxyPrefix) { - prepareNoCache(rw) - } - - switch path := req.URL.Path; { - case path == p.RobotsPath: - p.RobotsTxt(rw) - case p.IsAllowedRequest(req): - p.SkipAuthProxy(rw, req) - case path == p.SignInPath: - p.SignIn(rw, req) - case path == p.SignOutPath: - p.SignOut(rw, req) - case path == p.OAuthStartPath: - p.OAuthStart(rw, req) - case path == p.OAuthCallbackPath: - p.OAuthCallback(rw, req) - case path == p.AuthOnlyPath: - p.AuthOnly(rw, req) - case path == p.UserInfoPath: - p.UserInfo(rw, req) - default: - p.Proxy(rw, req) - } -} - // SignIn serves a page prompting users to sign in func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { - redirect, err := p.GetAppRedirect(req) + redirect, err := p.getAppRedirect(req) if err != nil { logger.Errorf("Error obtaining redirect: %v", err) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) @@ -893,7 +692,7 @@ func (p *OAuthProxy) UserInfo(rw http.ResponseWriter, req *http.Request) { // SignOut sends a response to clear the authentication cookie func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) { - redirect, err := p.GetAppRedirect(req) + redirect, err := p.getAppRedirect(req) if err != nil { logger.Errorf("Error obtaining redirect: %v", err) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) @@ -918,13 +717,13 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) { return } p.SetCSRFCookie(rw, req, nonce) - redirect, err := p.GetAppRedirect(req) + redirect, err := p.getAppRedirect(req) if err != nil { logger.Errorf("Error obtaining redirect: %v", err) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) return } - redirectURI := p.GetOAuthRedirectURI(requestutil.GetRequestHost(req)) + redirectURI := p.getOAuthRedirectURI(req) http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), http.StatusFound) } @@ -947,7 +746,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { return } - session, err := p.redeemCode(req.Context(), requestutil.GetRequestHost(req), req.Form.Get("code")) + session, err := p.redeemCode(req) if err != nil { logger.Errorf("Error redeeming code during OAuth2 callback: %v", err) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error") @@ -1006,6 +805,32 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { } } +func (p *OAuthProxy) redeemCode(req *http.Request) (*sessionsapi.SessionState, error) { + code := req.Form.Get("code") + if code == "" { + return nil, providers.ErrMissingCode + } + + redirectURI := p.getOAuthRedirectURI(req) + s, err := p.provider.Redeem(req.Context(), redirectURI, code) + if err != nil { + return nil, err + } + return s, nil +} + +func (p *OAuthProxy) enrichSessionState(ctx context.Context, s *sessionsapi.SessionState) error { + var err error + if s.Email == "" { + s.Email, err = p.provider.GetEmailAddress(ctx, s) + if err != nil && !errors.Is(err, providers.ErrNotImplemented) { + return err + } + } + + return p.provider.EnrichSession(ctx, s) +} + // AuthOnly checks whether the user is currently logged in (both authentication // and optional authorization). func (p *OAuthProxy) AuthOnly(rw http.ResponseWriter, req *http.Request) { @@ -1023,7 +848,7 @@ func (p *OAuthProxy) AuthOnly(rw http.ResponseWriter, req *http.Request) { } // we are authenticated - p.addHeadersForProxying(rw, req, session) + p.addHeadersForProxying(rw, session) p.headersChain.Then(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(http.StatusAccepted) })).ServeHTTP(rw, req) @@ -1041,13 +866,13 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { switch err { case nil: // we are authenticated - p.addHeadersForProxying(rw, req, session) + p.addHeadersForProxying(rw, session) p.headersChain.Then(p.serveMux).ServeHTTP(rw, req) case ErrNeedsLogin: // we need to send the user to a login screen if isAjax(req) { // no point redirecting an AJAX request - p.ErrorJSON(rw, http.StatusUnauthorized) + p.errorJSON(rw, http.StatusUnauthorized) return } @@ -1066,7 +891,184 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { p.ErrorPage(rw, http.StatusInternalServerError, "Internal Error", "Internal Error") } +} +// See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en +var noCacheHeaders = map[string]string{ + "Expires": time.Unix(0, 0).Format(time.RFC1123), + "Cache-Control": "no-cache, no-store, must-revalidate, max-age=0", + "X-Accel-Expires": "0", // https://www.nginx.com/resources/wiki/start/topics/examples/x-accel/ +} + +// prepareNoCache prepares headers for preventing browser caching. +func prepareNoCache(w http.ResponseWriter) { + // Set NoCache headers + for k, v := range noCacheHeaders { + w.Header().Set(k, v) + } +} + +// getOAuthRedirectURI returns the redirectURL that the upstream OAuth Provider will +// redirect clients to once authenticated. +// This is usually the OAuthProxy callback URL. +func (p *OAuthProxy) getOAuthRedirectURI(req *http.Request) string { + // if `p.redirectURL` already has a host, return it + if p.redirectURL.Host != "" { + return p.redirectURL.String() + } + + // Otherwise figure out the scheme + host from the request + rd := *p.redirectURL + rd.Host = requestutil.GetRequestHost(req) + rd.Scheme = requestutil.GetRequestProto(req) + + // If CookieSecure is true, return `https` no matter what + // Not all reverse proxies set X-Forwarded-Proto + if p.CookieSecure { + rd.Scheme = schemeHTTPS + } + return rd.String() +} + +// getAppRedirect determines the full URL or URI path to redirect clients to +// once authenticated with the OAuthProxy +// Strategy priority (first legal result is used): +// - `rd` querysting parameter +// - `X-Auth-Request-Redirect` header +// - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled) +// - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*) +// - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled) +// - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*) +// - `/` +func (p *OAuthProxy) getAppRedirect(req *http.Request) (string, error) { + err := req.ParseForm() + if err != nil { + return "", err + } + + // These redirect getter functions are strategies ordered by priority + // for figuring out the redirect URL. + type redirectGetter func(req *http.Request) string + for _, rdGetter := range []redirectGetter{ + p.getRdQuerystringRedirect, + p.getXAuthRequestRedirect, + p.getXForwardedHeadersRedirect, + p.getURIRedirect, + } { + if redirect := rdGetter(req); redirect != "" { + return redirect, nil + } + } + + return "/", nil +} + +func isForwardedRequest(req *http.Request) bool { + return requestutil.IsProxied(req) && + req.Host != requestutil.GetRequestHost(req) +} + +func (p *OAuthProxy) hasProxyPrefix(path string) bool { + return strings.HasPrefix(path, fmt.Sprintf("%s/", p.ProxyPrefix)) +} + +// getRdQuerystringRedirect handles this getAppRedirect strategy: +// - `rd` querysting parameter +func (p *OAuthProxy) getRdQuerystringRedirect(req *http.Request) string { + redirect := req.Form.Get("rd") + if p.IsValidRedirect(redirect) { + return redirect + } + return "" +} + +// getXAuthRequestRedirect handles this getAppRedirect strategy: +// - `X-Auth-Request-Redirect` Header +func (p *OAuthProxy) getXAuthRequestRedirect(req *http.Request) string { + redirect := req.Header.Get("X-Auth-Request-Redirect") + if p.IsValidRedirect(redirect) { + return redirect + } + return "" +} + +// getXForwardedHeadersRedirect handles these getAppRedirect strategies: +// - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled) +// - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*) +func (p *OAuthProxy) getXForwardedHeadersRedirect(req *http.Request) string { + if !isForwardedRequest(req) { + return "" + } + + uri := requestutil.GetRequestURI(req) + if p.hasProxyPrefix(uri) { + uri = "/" + } + + redirect := fmt.Sprintf( + "%s://%s%s", + requestutil.GetRequestProto(req), + requestutil.GetRequestHost(req), + uri, + ) + + if p.IsValidRedirect(redirect) { + return redirect + } + return "" +} + +// getURIRedirect handles these getAppRedirect strategies: +// - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled) +// - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*) +// - `/` +func (p *OAuthProxy) getURIRedirect(req *http.Request) string { + redirect := requestutil.GetRequestURI(req) + if !p.IsValidRedirect(redirect) { + redirect = req.URL.RequestURI() + } + + if p.hasProxyPrefix(redirect) { + return "/" + } + return redirect +} + +// splitHostPort separates host and port. If the port is not valid, it returns +// the entire input as host, and it doesn't check the validity of the host. +// Unlike net.SplitHostPort, but per RFC 3986, it requires ports to be numeric. +// *** taken from net/url, modified validOptionalPort() to accept ":*" +func splitHostPort(hostport string) (host, port string) { + host = hostport + + colon := strings.LastIndexByte(host, ':') + if colon != -1 && validOptionalPort(host[colon:]) { + host, port = host[:colon], host[colon+1:] + } + + if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { + host = host[1 : len(host)-1] + } + + return +} + +// validOptionalPort reports whether port is either an empty string +// or matches /^:\d*$/ +// *** taken from net/url, modified to accept ":*" +func validOptionalPort(port string) bool { + if port == "" || port == ":*" { + return true + } + if port[0] != ':' { + return false + } + for _, b := range port[1:] { + if b < '0' || b > '9' { + return false + } + } + return true } // getAuthenticatedSession checks whether a user is authenticated and returns a session object and nil error if so @@ -1153,7 +1155,7 @@ func extractAllowedGroups(req *http.Request) map[string]struct{} { } // addHeadersForProxying adds the appropriate headers the request / response for proxying -func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) { +func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, session *sessionsapi.SessionState) { if session.Email == "" { rw.Header().Set("GAP-Auth", session.User) } else { @@ -1181,8 +1183,8 @@ func isAjax(req *http.Request) bool { return false } -// ErrorJSON returns the error code with an application/json mime type -func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) { +// errorJSON returns the error code with an application/json mime type +func (p *OAuthProxy) errorJSON(rw http.ResponseWriter, code int) { rw.Header().Set("Content-Type", applicationJSON) rw.WriteHeader(code) } diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 8adea1ce..3366ef5f 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -415,8 +415,9 @@ func Test_redeemCode(t *testing.T) { t.Fatal(err) } - _, err = proxy.redeemCode(context.Background(), "www.example.com", "") - assert.Error(t, err) + req := httptest.NewRequest(http.MethodGet, "/", nil) + _, err = proxy.redeemCode(req) + assert.Equal(t, providers.ErrMissingCode, err) } func Test_enrichSession(t *testing.T) { @@ -1749,7 +1750,7 @@ func TestRequestSignature(t *testing.T) { } } -func TestGetRedirect(t *testing.T) { +func Test_getAppRedirect(t *testing.T) { opts := baseTestOptions() opts.WhitelistDomains = append(opts.WhitelistDomains, ".example.com", ".example.com:8443") err := validation.Validate(opts) @@ -1900,7 +1901,7 @@ func TestGetRedirect(t *testing.T) { req = middleware.AddRequestScope(req, &middleware.RequestScope{ ReverseProxy: tt.reverseProxy, }) - redirect, err := proxy.GetAppRedirect(req) + redirect, err := proxy.getAppRedirect(req) assert.NoError(t, err) assert.Equal(t, tt.expectedRedirect, redirect) From da02914a9c53a1db72994c7c9d41433cc32f8004 Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Sat, 9 Jan 2021 11:45:26 -0800 Subject: [PATCH 28/28] Log IsValidRedirect violations and do a final safety call --- CHANGELOG.md | 8 ++++++++ oauthproxy.go | 43 +++++++++++++++++++++++++++---------------- 2 files changed, 35 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0e38f0cd..5d0d8021 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ ## Important Notes +- [#964](https://github.com/oauth2-proxy/oauth2-proxy/pull/964) Redirect URL generation will attempt secondary strategies + in the priority chain if any fail the `IsValidRedirect` security check. Previously any failures fell back to `/`. - [#953](https://github.com/oauth2-proxy/oauth2-proxy/pull/953) Keycloak will now use `--profile-url` if set for the userinfo endpoint instead of `--validate-url`. `--validate-url` will still work for backwards compatibility. - [#957](https://github.com/oauth2-proxy/oauth2-proxy/pull/957) To use X-Forwarded-{Proto,Host,Uri} on redirect detection, `--reverse-proxy` must be `true`. @@ -36,6 +38,11 @@ ## Breaking Changes +- [#964](https://github.com/oauth2-proxy/oauth2-proxy/pull/964) `--reverse-proxy` must be true to trust `X-Forwarded-*` headers as canonical. + These are used throughout the application in redirect URLs, cookie domains and host logging logic. These are the headers: + - `X-Forwarded-Proto` instead of `req.URL.Scheme` + - `X-Forwarded-Host` instead of `req.Host` + - `X-Forwarded-Uri` instead of `req.URL.RequestURI()` - [#953](https://github.com/oauth2-proxy/oauth2-proxy/pull/953) In config files & envvar configs, `keycloak_group` is now the plural `keycloak_groups`. Flag configs are still `--keycloak-group` but it can be passed multiple times. - [#911](https://github.com/oauth2-proxy/oauth2-proxy/pull/911) Specifying a non-existent provider will cause OAuth2-Proxy to fail on startup instead of defaulting to "google". @@ -60,6 +67,7 @@ ## Changes since v6.1.1 - [#995](https://github.com/oauth2-proxy/oauth2-proxy/pull/995) Add Security Policy (@JoelSpeed) +- [#964](https://github.com/oauth2-proxy/oauth2-proxy/pull/964) Require `--reverse-proxy` true to trust `X-Forwareded-*` type headers (@NickMeves) - [#970](https://github.com/oauth2-proxy/oauth2-proxy/pull/970) Fix joined cookie name for those containing underline in the suffix (@peppered) - [#953](https://github.com/oauth2-proxy/oauth2-proxy/pull/953) Migrate Keycloak to EnrichSession & support multiple groups for authorization (@NickMeves) - [#957](https://github.com/oauth2-proxy/oauth2-proxy/pull/957) Use X-Forwarded-{Proto,Host,Uri} on redirect as last resort (@linuxgemini) diff --git a/oauthproxy.go b/oauthproxy.go index 28f667b3..36c58c46 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -955,7 +955,9 @@ func (p *OAuthProxy) getAppRedirect(req *http.Request) (string, error) { p.getXForwardedHeadersRedirect, p.getURIRedirect, } { - if redirect := rdGetter(req); redirect != "" { + redirect := rdGetter(req) + // Call `p.IsValidRedirect` again here a final time to be safe + if redirect != "" && p.IsValidRedirect(redirect) { return redirect, nil } } @@ -972,24 +974,32 @@ func (p *OAuthProxy) hasProxyPrefix(path string) bool { return strings.HasPrefix(path, fmt.Sprintf("%s/", p.ProxyPrefix)) } -// getRdQuerystringRedirect handles this getAppRedirect strategy: -// - `rd` querysting parameter -func (p *OAuthProxy) getRdQuerystringRedirect(req *http.Request) string { - redirect := req.Form.Get("rd") +func (p *OAuthProxy) validateRedirect(redirect string, errorFormat string) string { if p.IsValidRedirect(redirect) { return redirect } + if redirect != "" { + logger.Errorf(errorFormat, redirect) + } return "" } +// getRdQuerystringRedirect handles this getAppRedirect strategy: +// - `rd` querysting parameter +func (p *OAuthProxy) getRdQuerystringRedirect(req *http.Request) string { + return p.validateRedirect( + req.Form.Get("rd"), + "Invalid redirect provided in rd querystring parameter: %s", + ) +} + // getXAuthRequestRedirect handles this getAppRedirect strategy: // - `X-Auth-Request-Redirect` Header func (p *OAuthProxy) getXAuthRequestRedirect(req *http.Request) string { - redirect := req.Header.Get("X-Auth-Request-Redirect") - if p.IsValidRedirect(redirect) { - return redirect - } - return "" + return p.validateRedirect( + req.Header.Get("X-Auth-Request-Redirect"), + "Invalid redirect provided in X-Auth-Request-Redirect header: %s", + ) } // getXForwardedHeadersRedirect handles these getAppRedirect strategies: @@ -1012,10 +1022,8 @@ func (p *OAuthProxy) getXForwardedHeadersRedirect(req *http.Request) string { uri, ) - if p.IsValidRedirect(redirect) { - return redirect - } - return "" + return p.validateRedirect(redirect, + "Invalid redirect generated from X-Forwarded-* headers: %s") } // getURIRedirect handles these getAppRedirect strategies: @@ -1023,8 +1031,11 @@ func (p *OAuthProxy) getXForwardedHeadersRedirect(req *http.Request) string { // - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*) // - `/` func (p *OAuthProxy) getURIRedirect(req *http.Request) string { - redirect := requestutil.GetRequestURI(req) - if !p.IsValidRedirect(redirect) { + redirect := p.validateRedirect( + requestutil.GetRequestURI(req), + "Invalid redirect generated from X-Forwarded-Uri header: %s", + ) + if redirect == "" { redirect = req.URL.RequestURI() }