diff --git a/oauthproxy.go b/oauthproxy.go index c6db18a7..bdef853f 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -406,11 +406,11 @@ func buildSessionChain(opts *options.Options, provider providers.Provider, sessi middlewareapi.CreateTokenToSessionFunc(verifier.Verify)) } - chain = chain.Append(middleware.NewJwtSessionLoader(sessionLoaders, opts.BearerTokenLoginFallback)) + chain = chain.Append(middleware.NewJwtSessionLoader(sessionLoaders, opts.BearerTokenLoginFallback, opts.AuthorizationHeaderName)) } if validator != nil { - chain = chain.Append(middleware.NewBasicAuthSessionLoader(validator, opts.HtpasswdUserGroups, opts.LegacyPreferEmailToUser)) + chain = chain.Append(middleware.NewBasicAuthSessionLoader(validator, opts.HtpasswdUserGroups, opts.LegacyPreferEmailToUser, opts.AuthorizationHeaderName)) } chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{ @@ -1119,6 +1119,7 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R if session == nil { return nil, ErrNeedsLogin + } invalidEmail := session.Email != "" && !p.Validator(session.Email) diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 44569a35..b0ac7a52 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -1920,6 +1920,127 @@ func TestGetJwtSession(t *testing.T) { assert.Equal(t, test.rw.Header().Get("X-Auth-Request-Email"), "john@example.com") } +func TestGetJwtSessionCustomAuthorizationHeaderName(t *testing.T) { + goodJwt := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9." + + "eyJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjoiaHR0cHM6Ly90ZXN0Lm15YXBwLmNvbSIsIm5hbWUiOiJKb2huIERvZSIsImVtY" + + "WlsIjoiam9obkBleGFtcGxlLmNvbSIsImlzcyI6Imh0dHBzOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwiaWF0IjoxNTUzNjkxMj" + + "E1LCJleHAiOjE5MTIxNTE4MjF9." + + "rLVyzOnEldUq_pNkfa-WiV8TVJYWyZCaM2Am_uo8FGg11zD7l-qmz3x1seTvqpH6Y0Ty00fmv6dJnGnC8WMnPXQiodRTfhBSe" + + "OKZMu0HkMD2sg52zlKkbfLTO6ic5VnbVgwjjrB8am_Ta6w7kyFUaB5C1BsIrrLMldkWEhynbb8" + + keyset := NoOpKeySet{} + verifier := oidc.NewVerifier("https://issuer.example.com", keyset, + &oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true, + SkipClientIDCheck: true}) + verificationOptions := internaloidc.IDTokenVerificationOptions{ + AudienceClaims: []string{"aud"}, + ClientID: "https://test.myapp.com", + ExtraAudiences: []string{}, + } + internalVerifier := internaloidc.NewVerifier(verifier, verificationOptions) + + test, err := NewAuthOnlyEndpointTest("", func(opts *options.Options) { + opts.AuthorizationHeaderName = "Authorization-Custom" + opts.InjectRequestHeaders = []options.Header{ + { + Name: "Authorization-Custom", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "id_token", + Prefix: "Bearer ", + }, + }, + }, + }, + { + Name: "X-Forwarded-User", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "user", + }, + }, + }, + }, + { + Name: "X-Forwarded-Email", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "email", + }, + }, + }, + }, + } + + opts.InjectResponseHeaders = []options.Header{ + { + Name: "Authorization-Custom", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "id_token", + Prefix: "Bearer ", + }, + }, + }, + }, + { + Name: "X-Auth-Request-User", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "user", + }, + }, + }, + }, + { + Name: "X-Auth-Request-Email", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "email", + }, + }, + }, + }, + } + opts.SkipJwtBearerTokens = true + opts.AuthorizationHeaderName = "Authorization-Custom" + opts.SetJWTBearerVerifiers(append(opts.GetJWTBearerVerifiers(), internalVerifier)) + }) + if err != nil { + t.Fatal(err) + } + tp, _ := test.proxy.provider.(*TestProvider) + tp.GroupValidator = func(s string) bool { + return true + } + + authHeader := fmt.Sprintf("Bearer %s", goodJwt) + test.req.Header = map[string][]string{ + "Authorization-Custom": {authHeader}, + } + + test.proxy.ServeHTTP(test.rw, test.req) + if test.rw.Code >= 400 { + t.Fatalf("expected 3xx got %d", test.rw.Code) + } + + // Check PassAuthorization, should overwrite Basic header + assert.Equal(t, test.req.Header.Get("Authorization-Custom"), authHeader) + assert.Equal(t, test.req.Header.Get("X-Forwarded-User"), "1234567890") + assert.Equal(t, test.req.Header.Get("X-Forwarded-Email"), "john@example.com") + + // SetAuthorization and SetXAuthRequest + assert.Equal(t, test.rw.Header().Get("Authorization-Custom"), authHeader) + assert.Equal(t, test.rw.Header().Get("X-Auth-Request-User"), "1234567890") + assert.Equal(t, test.rw.Header().Get("X-Auth-Request-Email"), "john@example.com") +} + func Test_prepareNoCache(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { prepareNoCache(w) diff --git a/pkg/apis/options/legacy_options.go b/pkg/apis/options/legacy_options.go index e22278fa..d232ed82 100644 --- a/pkg/apis/options/legacy_options.go +++ b/pkg/apis/options/legacy_options.go @@ -39,9 +39,10 @@ func NewLegacyOptions() *LegacyOptions { }, LegacyHeaders: LegacyHeaders{ - PassBasicAuth: true, - PassUserHeaders: true, - SkipAuthStripHeaders: true, + PassBasicAuth: true, + PassUserHeaders: true, + SkipAuthStripHeaders: true, + AuthorizationHeaderName: "Authorization", }, LegacyServer: LegacyServer{ @@ -90,6 +91,8 @@ func (l *LegacyOptions) ToOptions() (*Options, error) { l.Options.LegacyPreferEmailToUser = l.LegacyHeaders.PreferEmailToUser + l.Options.AuthorizationHeaderName = l.LegacyHeaders.AuthorizationHeaderName + providers, err := l.LegacyProvider.convert() if err != nil { return nil, fmt.Errorf("error converting provider: %v", err) @@ -201,9 +204,10 @@ type LegacyHeaders struct { SetXAuthRequest bool `flag:"set-xauthrequest" cfg:"set_xauthrequest"` SetAuthorization bool `flag:"set-authorization-header" cfg:"set_authorization_header"` - PreferEmailToUser bool `flag:"prefer-email-to-user" cfg:"prefer_email_to_user"` - BasicAuthPassword string `flag:"basic-auth-password" cfg:"basic_auth_password"` - SkipAuthStripHeaders bool `flag:"skip-auth-strip-headers" cfg:"skip_auth_strip_headers"` + PreferEmailToUser bool `flag:"prefer-email-to-user" cfg:"prefer_email_to_user"` + BasicAuthPassword string `flag:"basic-auth-password" cfg:"basic_auth_password"` + AuthorizationHeaderName string `flag:"authorization-header-name" cfg:"authorization_header_name"` + SkipAuthStripHeaders bool `flag:"skip-auth-strip-headers" cfg:"skip_auth_strip_headers"` } func legacyHeadersFlagSet() *pflag.FlagSet { @@ -220,6 +224,7 @@ func legacyHeadersFlagSet() *pflag.FlagSet { flagSet.Bool("prefer-email-to-user", false, "Prefer to use the Email address as the Username when passing information to upstream. Will only use Username if Email is unavailable, eg. htaccess authentication. Used in conjunction with -pass-basic-auth and -pass-user-headers") flagSet.String("basic-auth-password", "", "the password to set when passing the HTTP Basic Auth header") + flagSet.String("authorization-header-name", "Authorization", "name of the authorization header to use instead of Authorization") flagSet.Bool("skip-auth-strip-headers", true, "strips X-Forwarded-* style authentication headers & Authorization header if they would be set by oauth2-proxy") return flagSet @@ -235,7 +240,7 @@ func (l *LegacyHeaders) getRequestHeaders() []Header { requestHeaders := []Header{} if l.PassBasicAuth && l.BasicAuthPassword != "" { - requestHeaders = append(requestHeaders, getBasicAuthHeader(l.PreferEmailToUser, l.BasicAuthPassword)) + requestHeaders = append(requestHeaders, getBasicAuthHeader(l.PreferEmailToUser, l.BasicAuthPassword, l.AuthorizationHeaderName)) } // In the old implementation, PassUserHeaders is a subset of PassBasicAuth @@ -249,7 +254,7 @@ func (l *LegacyHeaders) getRequestHeaders() []Header { } if l.PassAuthorization { - requestHeaders = append(requestHeaders, getAuthorizationHeader()) + requestHeaders = append(requestHeaders, getAuthorizationHeader(l.AuthorizationHeaderName)) } for i := range requestHeaders { @@ -270,24 +275,28 @@ func (l *LegacyHeaders) getResponseHeaders() []Header { } if l.SetBasicAuth { - responseHeaders = append(responseHeaders, getBasicAuthHeader(l.PreferEmailToUser, l.BasicAuthPassword)) + responseHeaders = append(responseHeaders, getBasicAuthHeader(l.PreferEmailToUser, l.BasicAuthPassword, l.AuthorizationHeaderName)) } if l.SetAuthorization { - responseHeaders = append(responseHeaders, getAuthorizationHeader()) + responseHeaders = append(responseHeaders, getAuthorizationHeader(l.AuthorizationHeaderName)) } return responseHeaders } -func getBasicAuthHeader(preferEmailToUser bool, basicAuthPassword string) Header { +func getBasicAuthHeader(preferEmailToUser bool, basicAuthPassword string, headerName string) Header { claim := "user" if preferEmailToUser { claim = "email" } + if headerName == "" { + headerName = "Authorization" + } + return Header{ - Name: "Authorization", + Name: headerName, Values: []HeaderValue{ { ClaimSource: &ClaimSource{ @@ -368,9 +377,13 @@ func getPassAccessTokenHeader() Header { } } -func getAuthorizationHeader() Header { +func getAuthorizationHeader(headerName string) Header { + if headerName == "" { + headerName = "Authorization" + } + return Header{ - Name: "Authorization", + Name: headerName, Values: []HeaderValue{ { ClaimSource: &ClaimSource{ diff --git a/pkg/apis/options/load_test.go b/pkg/apis/options/load_test.go index 06123c37..42854e19 100644 --- a/pkg/apis/options/load_test.go +++ b/pkg/apis/options/load_test.go @@ -25,9 +25,10 @@ var _ = Describe("Load", func() { }, LegacyHeaders: LegacyHeaders{ - PassBasicAuth: true, - PassUserHeaders: true, - SkipAuthStripHeaders: true, + PassBasicAuth: true, + PassUserHeaders: true, + AuthorizationHeaderName: "Authorization", + SkipAuthStripHeaders: true, }, LegacyServer: LegacyServer{ @@ -48,6 +49,7 @@ var _ = Describe("Load", func() { Options: Options{ BearerTokenLoginFallback: true, + AuthorizationHeaderName: "Authorization", ProxyPrefix: "/oauth2", PingPath: "/ping", ReadyPath: "/ready", diff --git a/pkg/apis/options/options.go b/pkg/apis/options/options.go index 8fa72c7c..8a8d2ea3 100644 --- a/pkg/apis/options/options.go +++ b/pkg/apis/options/options.go @@ -56,6 +56,7 @@ type Options struct { SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"` SkipAuthRoutes []string `flag:"skip-auth-route" cfg:"skip_auth_routes"` SkipJwtBearerTokens bool `flag:"skip-jwt-bearer-tokens" cfg:"skip_jwt_bearer_tokens"` + AuthorizationHeaderName string `flag:"authorization-header-name" cfg:"authorization_header_name"` BearerTokenLoginFallback bool `flag:"bearer-token-login-fallback" cfg:"bearer_token_login_fallback"` ExtraJwtIssuers []string `flag:"extra-jwt-issuers" cfg:"extra_jwt_issuers"` SkipProviderButton bool `flag:"skip-provider-button" cfg:"skip_provider_button"` @@ -99,6 +100,7 @@ func (o *Options) SetRealClientIPParser(s ipapi.RealClientIPParser) { o.re func NewOptions() *Options { return &Options{ BearerTokenLoginFallback: true, + AuthorizationHeaderName: "Authorization", ProxyPrefix: "/oauth2", Providers: providerDefaults(), PingPath: "/ping", @@ -130,6 +132,7 @@ func NewFlagSet() *pflag.FlagSet { flagSet.Bool("skip-auth-preflight", false, "will skip authentication for OPTIONS requests") flagSet.Bool("ssl-insecure-skip-verify", false, "skip validation of certificates presented when using HTTPS providers") flagSet.Bool("skip-jwt-bearer-tokens", false, "will skip requests that have verified JWT bearer tokens (default false)") + flagSet.String("authorization-header-name", "Authorization", "name of the authorization header to use instead of Authorization") flagSet.Bool("bearer-token-login-fallback", true, "if skip-jwt-bearer-tokens is set, fall back to normal login redirect with an invalid JWT. If false, 403 instead") flagSet.Bool("force-json-errors", false, "will force JSON errors instead of HTTP error pages or redirects") flagSet.Bool("encode-state", false, "will encode oauth state with base64") diff --git a/pkg/middleware/basic_session.go b/pkg/middleware/basic_session.go index 71b822c0..6effb00c 100644 --- a/pkg/middleware/basic_session.go +++ b/pkg/middleware/basic_session.go @@ -11,9 +11,12 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" ) -func NewBasicAuthSessionLoader(validator basic.Validator, sessionGroups []string, preferEmail bool) alice.Constructor { +func NewBasicAuthSessionLoader(validator basic.Validator, sessionGroups []string, preferEmail bool, authorizationHeaderName string) alice.Constructor { + if authorizationHeaderName == "" { + authorizationHeaderName = "Authorization" + } return func(next http.Handler) http.Handler { - return loadBasicAuthSession(validator, sessionGroups, preferEmail, next) + return loadBasicAuthSession(validator, sessionGroups, preferEmail, authorizationHeaderName, next) } } @@ -22,15 +25,17 @@ func NewBasicAuthSessionLoader(validator basic.Validator, sessionGroups []string // If no authorization header is found, or the header is invalid, no session // will be loaded and the request will be passed to the next handler. // If a session was loaded by a previous handler, it will not be replaced. -func loadBasicAuthSession(validator basic.Validator, sessionGroups []string, preferEmail bool, next http.Handler) http.Handler { +func loadBasicAuthSession(validator basic.Validator, sessionGroups []string, preferEmail bool, authorizationHeaderName string, next http.Handler) http.Handler { // This is a hack to be backwards compatible with the old PreferEmailToUser option. // Long term we will have a rich static user configuration option and this will // be removed. // TODO(JoelSpeed): Remove this hack once rich static user config is implemented. - getSession := getBasicSession + getSession := func(validator basic.Validator, sessionGroups []string, req *http.Request) (*sessionsapi.SessionState, error) { + return getBasicSession(validator, sessionGroups, authorizationHeaderName, req) + } if preferEmail { getSession = func(validator basic.Validator, sessionGroups []string, req *http.Request) (*sessionsapi.SessionState, error) { - session, err := getBasicSession(validator, sessionGroups, req) + session, err := getBasicSession(validator, sessionGroups, authorizationHeaderName, req) if session != nil { session.Email = session.User } @@ -62,8 +67,9 @@ func loadBasicAuthSession(validator basic.Validator, sessionGroups []string, pre // getBasicSession attempts to load a basic session from the request. // If the credentials in the request exist within the htpasswdMap, // a new session will be created. -func getBasicSession(validator basic.Validator, sessionGroups []string, req *http.Request) (*sessionsapi.SessionState, error) { - auth := req.Header.Get("Authorization") +func getBasicSession(validator basic.Validator, sessionGroups []string, authorizationHeaderName string, req *http.Request) (*sessionsapi.SessionState, error) { + auth := req.Header.Get(authorizationHeaderName) + if auth == "" { // No auth header provided, so don't attempt to load a session return nil, nil diff --git a/pkg/middleware/basic_session_test.go b/pkg/middleware/basic_session_test.go index 61f3ecbc..fca238c1 100644 --- a/pkg/middleware/basic_session_test.go +++ b/pkg/middleware/basic_session_test.go @@ -55,7 +55,7 @@ var _ = Describe("Basic Auth Session Suite", func() { // Create the handler with a next handler that will capture the session // from the scope var gotSession *sessionsapi.SessionState - handler := NewBasicAuthSessionLoader(validator, in.sessionGroups, in.preferEmail)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handler := NewBasicAuthSessionLoader(validator, in.sessionGroups, in.preferEmail, "Authorization")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { gotSession = middlewareapi.GetRequestScope(r).Session })) handler.ServeHTTP(rw, req) diff --git a/pkg/middleware/jwt_session.go b/pkg/middleware/jwt_session.go index 790eb8b2..d8f4d613 100644 --- a/pkg/middleware/jwt_session.go +++ b/pkg/middleware/jwt_session.go @@ -15,11 +15,12 @@ import ( const jwtRegexFormat = `^ey[a-zA-Z0-9_-]*\.ey[a-zA-Z0-9_-]*\.[a-zA-Z0-9_-]+$` -func NewJwtSessionLoader(sessionLoaders []middlewareapi.TokenToSessionFunc, bearerTokenLoginFallback bool) alice.Constructor { +func NewJwtSessionLoader(sessionLoaders []middlewareapi.TokenToSessionFunc, bearerTokenLoginFallback bool, authorizationHeaderName string) alice.Constructor { js := &jwtSessionLoader{ - jwtRegex: regexp.MustCompile(jwtRegexFormat), - sessionLoaders: sessionLoaders, - denyInvalidJWTs: !bearerTokenLoginFallback, + jwtRegex: regexp.MustCompile(jwtRegexFormat), + sessionLoaders: sessionLoaders, + denyInvalidJWTs: !bearerTokenLoginFallback, + authorizationHeaderName: authorizationHeaderName, } return js.loadSession } @@ -27,9 +28,10 @@ func NewJwtSessionLoader(sessionLoaders []middlewareapi.TokenToSessionFunc, bear // jwtSessionLoader is responsible for loading sessions from JWTs in // Authorization headers. type jwtSessionLoader struct { - jwtRegex *regexp.Regexp - sessionLoaders []middlewareapi.TokenToSessionFunc - denyInvalidJWTs bool + jwtRegex *regexp.Regexp + sessionLoaders []middlewareapi.TokenToSessionFunc + denyInvalidJWTs bool + authorizationHeaderName string } // loadSession attempts to load a session from a JWT stored in an Authorization @@ -67,7 +69,8 @@ func (j *jwtSessionLoader) loadSession(next http.Handler) http.Handler { // getJwtSession loads a session based on a JWT token in the authorization header. // (see the config options skip-jwt-bearer-tokens, extra-jwt-issuers, and bearer-token-login-fallback) func (j *jwtSessionLoader) getJwtSession(req *http.Request) (*sessionsapi.SessionState, error) { - auth := req.Header.Get("Authorization") + auth := req.Header.Get(j.authorizationHeaderName) + if auth == "" { // No auth header provided, so don't attempt to load a session return nil, nil diff --git a/pkg/middleware/jwt_session_test.go b/pkg/middleware/jwt_session_test.go index 12f30f5c..7c5abaee 100644 --- a/pkg/middleware/jwt_session_test.go +++ b/pkg/middleware/jwt_session_test.go @@ -115,7 +115,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` // Create the handler with a next handler that will capture the session // from the scope var gotSession *sessionsapi.SessionState - handler := NewJwtSessionLoader(sessionLoaders, true)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handler := NewJwtSessionLoader(sessionLoaders, true, "Authorization")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { gotSession = middlewareapi.GetRequestScope(r).Session })) handler.ServeHTTP(rw, req) @@ -185,7 +185,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` // Create the handler with a next handler that will capture the session // from the scope var gotSession *sessionsapi.SessionState - handler := NewJwtSessionLoader(sessionLoaders, false)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handler := NewJwtSessionLoader(sessionLoaders, false, "Authorization")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { gotSession = middlewareapi.GetRequestScope(r).Session })) handler.ServeHTTP(rw, req) @@ -259,7 +259,8 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` ).Verify j = &jwtSessionLoader{ - jwtRegex: regexp.MustCompile(jwtRegexFormat), + jwtRegex: regexp.MustCompile(jwtRegexFormat), + authorizationHeaderName: "Authorization", sessionLoaders: []middlewareapi.TokenToSessionFunc{ middlewareapi.CreateTokenToSessionFunc(verifier), },