add test for JwtSession with custom AuthorizationHeaderName

Signed-off-by: dvmartinweigl <martin.weigl@datavisyn.io>
This commit is contained in:
dvmartinweigl 2025-11-12 15:04:05 +01:00
parent cf9a179276
commit ad442e92c8
6 changed files with 153 additions and 21 deletions

View File

@ -406,11 +406,11 @@ func buildSessionChain(opts *options.Options, provider providers.Provider, sessi
middlewareapi.CreateTokenToSessionFunc(verifier.Verify))
}
chain = chain.Append(middleware.NewJwtSessionLoader(sessionLoaders, opts.BearerTokenLoginFallback))
chain = chain.Append(middleware.NewJwtSessionLoader(sessionLoaders, opts.BearerTokenLoginFallback, opts.AuthorizationHeaderName))
}
if validator != nil {
chain = chain.Append(middleware.NewBasicAuthSessionLoader(validator, opts.HtpasswdUserGroups, opts.LegacyPreferEmailToUser))
chain = chain.Append(middleware.NewBasicAuthSessionLoader(validator, opts.HtpasswdUserGroups, opts.LegacyPreferEmailToUser, opts.AuthorizationHeaderName))
}
chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{
@ -1119,6 +1119,7 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R
if session == nil {
return nil, ErrNeedsLogin
}
invalidEmail := session.Email != "" && !p.Validator(session.Email)

View File

@ -1920,6 +1920,127 @@ func TestGetJwtSession(t *testing.T) {
assert.Equal(t, test.rw.Header().Get("X-Auth-Request-Email"), "john@example.com")
}
func TestGetJwtSessionCustomAuthorizationHeaderName(t *testing.T) {
goodJwt := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9." +
"eyJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjoiaHR0cHM6Ly90ZXN0Lm15YXBwLmNvbSIsIm5hbWUiOiJKb2huIERvZSIsImVtY" +
"WlsIjoiam9obkBleGFtcGxlLmNvbSIsImlzcyI6Imh0dHBzOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwiaWF0IjoxNTUzNjkxMj" +
"E1LCJleHAiOjE5MTIxNTE4MjF9." +
"rLVyzOnEldUq_pNkfa-WiV8TVJYWyZCaM2Am_uo8FGg11zD7l-qmz3x1seTvqpH6Y0Ty00fmv6dJnGnC8WMnPXQiodRTfhBSe" +
"OKZMu0HkMD2sg52zlKkbfLTO6ic5VnbVgwjjrB8am_Ta6w7kyFUaB5C1BsIrrLMldkWEhynbb8"
keyset := NoOpKeySet{}
verifier := oidc.NewVerifier("https://issuer.example.com", keyset,
&oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true,
SkipClientIDCheck: true})
verificationOptions := internaloidc.IDTokenVerificationOptions{
AudienceClaims: []string{"aud"},
ClientID: "https://test.myapp.com",
ExtraAudiences: []string{},
}
internalVerifier := internaloidc.NewVerifier(verifier, verificationOptions)
test, err := NewAuthOnlyEndpointTest("", func(opts *options.Options) {
opts.AuthorizationHeaderName = "Authorization-Custom"
opts.InjectRequestHeaders = []options.Header{
{
Name: "Authorization-Custom",
Values: []options.HeaderValue{
{
ClaimSource: &options.ClaimSource{
Claim: "id_token",
Prefix: "Bearer ",
},
},
},
},
{
Name: "X-Forwarded-User",
Values: []options.HeaderValue{
{
ClaimSource: &options.ClaimSource{
Claim: "user",
},
},
},
},
{
Name: "X-Forwarded-Email",
Values: []options.HeaderValue{
{
ClaimSource: &options.ClaimSource{
Claim: "email",
},
},
},
},
}
opts.InjectResponseHeaders = []options.Header{
{
Name: "Authorization-Custom",
Values: []options.HeaderValue{
{
ClaimSource: &options.ClaimSource{
Claim: "id_token",
Prefix: "Bearer ",
},
},
},
},
{
Name: "X-Auth-Request-User",
Values: []options.HeaderValue{
{
ClaimSource: &options.ClaimSource{
Claim: "user",
},
},
},
},
{
Name: "X-Auth-Request-Email",
Values: []options.HeaderValue{
{
ClaimSource: &options.ClaimSource{
Claim: "email",
},
},
},
},
}
opts.SkipJwtBearerTokens = true
opts.AuthorizationHeaderName = "Authorization-Custom"
opts.SetJWTBearerVerifiers(append(opts.GetJWTBearerVerifiers(), internalVerifier))
})
if err != nil {
t.Fatal(err)
}
tp, _ := test.proxy.provider.(*TestProvider)
tp.GroupValidator = func(s string) bool {
return true
}
authHeader := fmt.Sprintf("Bearer %s", goodJwt)
test.req.Header = map[string][]string{
"Authorization-Custom": {authHeader},
}
test.proxy.ServeHTTP(test.rw, test.req)
if test.rw.Code >= 400 {
t.Fatalf("expected 3xx got %d", test.rw.Code)
}
// Check PassAuthorization, should overwrite Basic header
assert.Equal(t, test.req.Header.Get("Authorization-Custom"), authHeader)
assert.Equal(t, test.req.Header.Get("X-Forwarded-User"), "1234567890")
assert.Equal(t, test.req.Header.Get("X-Forwarded-Email"), "john@example.com")
// SetAuthorization and SetXAuthRequest
assert.Equal(t, test.rw.Header().Get("Authorization-Custom"), authHeader)
assert.Equal(t, test.rw.Header().Get("X-Auth-Request-User"), "1234567890")
assert.Equal(t, test.rw.Header().Get("X-Auth-Request-Email"), "john@example.com")
}
func Test_prepareNoCache(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
prepareNoCache(w)

View File

@ -11,9 +11,12 @@ import (
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
)
func NewBasicAuthSessionLoader(validator basic.Validator, sessionGroups []string, preferEmail bool) alice.Constructor {
func NewBasicAuthSessionLoader(validator basic.Validator, sessionGroups []string, preferEmail bool, authorizationHeaderName string) alice.Constructor {
if authorizationHeaderName == "" {
authorizationHeaderName = "Authorization"
}
return func(next http.Handler) http.Handler {
return loadBasicAuthSession(validator, sessionGroups, preferEmail, next)
return loadBasicAuthSession(validator, sessionGroups, preferEmail, authorizationHeaderName, next)
}
}
@ -22,15 +25,17 @@ func NewBasicAuthSessionLoader(validator basic.Validator, sessionGroups []string
// If no authorization header is found, or the header is invalid, no session
// will be loaded and the request will be passed to the next handler.
// If a session was loaded by a previous handler, it will not be replaced.
func loadBasicAuthSession(validator basic.Validator, sessionGroups []string, preferEmail bool, next http.Handler) http.Handler {
func loadBasicAuthSession(validator basic.Validator, sessionGroups []string, preferEmail bool, authorizationHeaderName string, next http.Handler) http.Handler {
// This is a hack to be backwards compatible with the old PreferEmailToUser option.
// Long term we will have a rich static user configuration option and this will
// be removed.
// TODO(JoelSpeed): Remove this hack once rich static user config is implemented.
getSession := getBasicSession
getSession := func(validator basic.Validator, sessionGroups []string, req *http.Request) (*sessionsapi.SessionState, error) {
return getBasicSession(validator, sessionGroups, authorizationHeaderName, req)
}
if preferEmail {
getSession = func(validator basic.Validator, sessionGroups []string, req *http.Request) (*sessionsapi.SessionState, error) {
session, err := getBasicSession(validator, sessionGroups, req)
session, err := getBasicSession(validator, sessionGroups, authorizationHeaderName, req)
if session != nil {
session.Email = session.User
}
@ -62,8 +67,9 @@ func loadBasicAuthSession(validator basic.Validator, sessionGroups []string, pre
// getBasicSession attempts to load a basic session from the request.
// If the credentials in the request exist within the htpasswdMap,
// a new session will be created.
func getBasicSession(validator basic.Validator, sessionGroups []string, req *http.Request) (*sessionsapi.SessionState, error) {
auth := req.Header.Get("Authorization")
func getBasicSession(validator basic.Validator, sessionGroups []string, authorizationHeaderName string, req *http.Request) (*sessionsapi.SessionState, error) {
auth := req.Header.Get(authorizationHeaderName)
if auth == "" {
// No auth header provided, so don't attempt to load a session
return nil, nil

View File

@ -55,7 +55,7 @@ var _ = Describe("Basic Auth Session Suite", func() {
// Create the handler with a next handler that will capture the session
// from the scope
var gotSession *sessionsapi.SessionState
handler := NewBasicAuthSessionLoader(validator, in.sessionGroups, in.preferEmail)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handler := NewBasicAuthSessionLoader(validator, in.sessionGroups, in.preferEmail, "Authorization")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotSession = middlewareapi.GetRequestScope(r).Session
}))
handler.ServeHTTP(rw, req)

View File

@ -15,11 +15,12 @@ import (
const jwtRegexFormat = `^ey[a-zA-Z0-9_-]*\.ey[a-zA-Z0-9_-]*\.[a-zA-Z0-9_-]+$`
func NewJwtSessionLoader(sessionLoaders []middlewareapi.TokenToSessionFunc, bearerTokenLoginFallback bool) alice.Constructor {
func NewJwtSessionLoader(sessionLoaders []middlewareapi.TokenToSessionFunc, bearerTokenLoginFallback bool, authorizationHeaderName string) alice.Constructor {
js := &jwtSessionLoader{
jwtRegex: regexp.MustCompile(jwtRegexFormat),
sessionLoaders: sessionLoaders,
denyInvalidJWTs: !bearerTokenLoginFallback,
jwtRegex: regexp.MustCompile(jwtRegexFormat),
sessionLoaders: sessionLoaders,
denyInvalidJWTs: !bearerTokenLoginFallback,
authorizationHeaderName: authorizationHeaderName,
}
return js.loadSession
}
@ -27,9 +28,10 @@ func NewJwtSessionLoader(sessionLoaders []middlewareapi.TokenToSessionFunc, bear
// jwtSessionLoader is responsible for loading sessions from JWTs in
// Authorization headers.
type jwtSessionLoader struct {
jwtRegex *regexp.Regexp
sessionLoaders []middlewareapi.TokenToSessionFunc
denyInvalidJWTs bool
jwtRegex *regexp.Regexp
sessionLoaders []middlewareapi.TokenToSessionFunc
denyInvalidJWTs bool
authorizationHeaderName string
}
// loadSession attempts to load a session from a JWT stored in an Authorization
@ -67,7 +69,8 @@ func (j *jwtSessionLoader) loadSession(next http.Handler) http.Handler {
// getJwtSession loads a session based on a JWT token in the authorization header.
// (see the config options skip-jwt-bearer-tokens, extra-jwt-issuers, and bearer-token-login-fallback)
func (j *jwtSessionLoader) getJwtSession(req *http.Request) (*sessionsapi.SessionState, error) {
auth := req.Header.Get("Authorization")
auth := req.Header.Get(j.authorizationHeaderName)
if auth == "" {
// No auth header provided, so don't attempt to load a session
return nil, nil

View File

@ -115,7 +115,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
// Create the handler with a next handler that will capture the session
// from the scope
var gotSession *sessionsapi.SessionState
handler := NewJwtSessionLoader(sessionLoaders, true)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handler := NewJwtSessionLoader(sessionLoaders, true, "Authorization")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotSession = middlewareapi.GetRequestScope(r).Session
}))
handler.ServeHTTP(rw, req)
@ -185,7 +185,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
// Create the handler with a next handler that will capture the session
// from the scope
var gotSession *sessionsapi.SessionState
handler := NewJwtSessionLoader(sessionLoaders, false)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handler := NewJwtSessionLoader(sessionLoaders, false, "Authorization")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotSession = middlewareapi.GetRequestScope(r).Session
}))
handler.ServeHTTP(rw, req)
@ -259,7 +259,8 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
).Verify
j = &jwtSessionLoader{
jwtRegex: regexp.MustCompile(jwtRegexFormat),
jwtRegex: regexp.MustCompile(jwtRegexFormat),
authorizationHeaderName: "Authorization",
sessionLoaders: []middlewareapi.TokenToSessionFunc{
middlewareapi.CreateTokenToSessionFunc(verifier),
},