From 8059a812cdd88024eb37d2df6e0aed6cf5a118ca Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Wed, 29 Jul 2020 20:10:14 +0100 Subject: [PATCH] Integrate new header injectors with OAuth2 Proxy --- main.go | 30 +- oauthproxy.go | 228 ++++-------- oauthproxy_test.go | 621 ++++++++++++++++---------------- pkg/validation/options.go | 8 - pkg/validation/options_test.go | 23 -- pkg/validation/sessions.go | 25 +- pkg/validation/sessions_test.go | 101 +++++- 7 files changed, 485 insertions(+), 551 deletions(-) diff --git a/main.go b/main.go index 8cd3ee5b..151c8331 100644 --- a/main.go +++ b/main.go @@ -3,17 +3,14 @@ package main import ( "fmt" "math/rand" - "net" "os" "os/signal" "runtime" "syscall" "time" - "github.com/justinas/alice" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/validation" ) @@ -63,33 +60,8 @@ func main() { rand.Seed(time.Now().UnixNano()) - chain := alice.New() - - if opts.ForceHTTPS { - _, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress) - if err != nil { - logger.Fatalf("FATAL: invalid HTTPS address %q: %v", opts.HTTPAddress, err) - } - chain = chain.Append(middleware.NewRedirectToHTTPS(httpsPort)) - } - - healthCheckPaths := []string{opts.PingPath} - healthCheckUserAgents := []string{opts.PingUserAgent} - if opts.GCPHealthChecks { - healthCheckPaths = append(healthCheckPaths, "/liveness_check", "/readiness_check") - healthCheckUserAgents = append(healthCheckUserAgents, "GoogleHC/1.0") - } - - // To silence logging of health checks, register the health check handler before - // the logging handler - if opts.Logging.SilencePing { - chain = chain.Append(middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents), LoggingHandler) - } else { - chain = chain.Append(LoggingHandler, middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents)) - } - s := &Server{ - Handler: chain.Then(oauthproxy), + Handler: oauthproxy, Opts: opts, stop: make(chan struct{}, 1), } diff --git a/oauthproxy.go b/oauthproxy.go index 4fd03a40..c349172a 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -2,7 +2,6 @@ package main import ( "context" - b64 "encoding/base64" "encoding/json" "errors" "fmt" @@ -98,7 +97,6 @@ type OAuthProxy struct { PassAuthorization bool PreferEmailToUser bool skipAuthPreflight bool - skipAuthStripHeaders bool skipJwtBearerTokens bool mainJwtBearerVerifier *oidc.IDTokenVerifier extraJwtBearerVerifiers []*oidc.IDTokenVerifier @@ -110,6 +108,8 @@ type OAuthProxy struct { AllowedGroups []string sessionChain alice.Chain + headersChain alice.Chain + preAuthChain alice.Chain } // NewOAuthProxy creates a new instance of OAuthProxy from the options provided @@ -169,7 +169,15 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr return nil, err } + preAuthChain, err := buildPreAuthChain(opts) + if err != nil { + return nil, fmt.Errorf("could not build pre-auth chain: %v", err) + } sessionChain := buildSessionChain(opts, sessionStore, basicAuthValidator) + headersChain, err := buildHeadersChain(opts) + if err != nil { + return nil, fmt.Errorf("could not build headers chain: %v", err) + } return &OAuthProxy{ CookieName: opts.Cookie.Name, @@ -201,20 +209,10 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr allowedRoutes: allowedRoutes, whitelistDomains: opts.WhitelistDomains, skipAuthPreflight: opts.SkipAuthPreflight, - skipAuthStripHeaders: opts.SkipAuthStripHeaders, skipJwtBearerTokens: opts.SkipJwtBearerTokens, mainJwtBearerVerifier: opts.GetOIDCVerifier(), extraJwtBearerVerifiers: opts.GetJWTBearerVerifiers(), realClientIPParser: opts.GetRealClientIPParser(), - SetXAuthRequest: opts.SetXAuthRequest, - PassBasicAuth: opts.PassBasicAuth, - SetBasicAuth: opts.SetBasicAuth, - PassUserHeaders: opts.PassUserHeaders, - BasicAuthPassword: opts.BasicAuthPassword, - PassAccessToken: opts.PassAccessToken, - SetAuthorization: opts.SetAuthorization, - PassAuthorization: opts.PassAuthorization, - PreferEmailToUser: opts.PreferEmailToUser, SkipProviderButton: opts.SkipProviderButton, templates: templates, trustedIPs: trustedIPs, @@ -226,12 +224,46 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr basicAuthValidator: basicAuthValidator, displayHtpasswdForm: basicAuthValidator != nil, sessionChain: sessionChain, + headersChain: headersChain, + preAuthChain: preAuthChain, }, nil } -func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionStore, validator basic.Validator) alice.Chain { +// buildPreAuthChain constructs a chain that should process every request before +// the OAuth2 Proxy authentication logic kicks in. +// For example forcing HTTPS or health checks. +func buildPreAuthChain(opts *options.Options) (alice.Chain, error) { chain := alice.New(middleware.NewScope()) + if opts.ForceHTTPS { + _, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress) + if err != nil { + return alice.Chain{}, fmt.Errorf("invalid HTTPS address %q: %v", opts.HTTPAddress, err) + } + chain = chain.Append(middleware.NewRedirectToHTTPS(httpsPort)) + } + + healthCheckPaths := []string{opts.PingPath} + healthCheckUserAgents := []string{opts.PingUserAgent} + if opts.GCPHealthChecks { + healthCheckPaths = append(healthCheckPaths, "/liveness_check", "/readiness_check") + healthCheckUserAgents = append(healthCheckUserAgents, "GoogleHC/1.0") + } + + // To silence logging of health checks, register the health check handler before + // the logging handler + if opts.Logging.SilencePing { + chain = chain.Append(middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents), LoggingHandler) + } else { + chain = chain.Append(LoggingHandler, middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents)) + } + + return chain, nil +} + +func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionStore, validator basic.Validator) alice.Chain { + chain := alice.New() + if opts.SkipJwtBearerTokens { sessionLoaders := []middlewareapi.TokenToSessionLoader{} if opts.GetOIDCVerifier() != nil { @@ -264,6 +296,20 @@ func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionSt return chain } +func buildHeadersChain(opts *options.Options) (alice.Chain, error) { + requestInjector, err := middleware.NewRequestHeaderInjector(opts.InjectRequestHeaders) + if err != nil { + return alice.Chain{}, fmt.Errorf("error constructing request header injector: %v", err) + } + + responseInjector, err := middleware.NewResponseHeaderInjector(opts.InjectResponseHeaders) + if err != nil { + return alice.Chain{}, fmt.Errorf("error constructing request header injector: %v", err) + } + + return alice.New(requestInjector, responseInjector), nil +} + func buildSignInMessage(opts *options.Options) string { var msg string if len(opts.Banner) >= 1 { @@ -685,6 +731,10 @@ func (p *OAuthProxy) IsTrustedIP(req *http.Request) bool { } func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + p.preAuthChain.Then(http.HandlerFunc(p.serveHTTP)).ServeHTTP(rw, req) +} + +func (p *OAuthProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) { if req.URL.Path != p.AuthOnlyPath && strings.HasPrefix(req.URL.Path, p.ProxyPrefix) { prepareNoCache(rw) } @@ -884,15 +934,14 @@ func (p *OAuthProxy) AuthenticateOnly(rw http.ResponseWriter, req *http.Request) // we are authenticated p.addHeadersForProxying(rw, req, session) - rw.WriteHeader(http.StatusAccepted) + p.headersChain.Then(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(http.StatusAccepted) + })).ServeHTTP(rw, req) } // SkipAuthProxy proxies allowlisted requests and skips authentication func (p *OAuthProxy) SkipAuthProxy(rw http.ResponseWriter, req *http.Request) { - if p.skipAuthStripHeaders { - p.stripAuthHeaders(req) - } - p.serveMux.ServeHTTP(rw, req) + p.headersChain.Then(p.serveMux).ServeHTTP(rw, req) } // Proxy proxies the user request if the user is authenticated else it prompts @@ -903,8 +952,7 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { case nil: // we are authenticated p.addHeadersForProxying(rw, req, session) - p.serveMux.ServeHTTP(rw, req) - + p.headersChain.Then(p.serveMux).ServeHTTP(rw, req) case ErrNeedsLogin: // we need to send the user to a login screen if isAjax(req) { @@ -961,120 +1009,6 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R // addHeadersForProxying adds the appropriate headers the request / response for proxying func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) { - if p.PassBasicAuth { - if p.PreferEmailToUser && session.Email != "" { - req.SetBasicAuth(session.Email, p.BasicAuthPassword) - req.Header["X-Forwarded-User"] = []string{session.Email} - req.Header.Del("X-Forwarded-Email") - } else { - req.SetBasicAuth(session.User, p.BasicAuthPassword) - req.Header["X-Forwarded-User"] = []string{session.User} - if session.Email != "" { - req.Header["X-Forwarded-Email"] = []string{session.Email} - } else { - req.Header.Del("X-Forwarded-Email") - } - } - if session.PreferredUsername != "" { - req.Header["X-Forwarded-Preferred-Username"] = []string{session.PreferredUsername} - } else { - req.Header.Del("X-Forwarded-Preferred-Username") - } - } - - if p.PassUserHeaders { - if p.PreferEmailToUser && session.Email != "" { - req.Header["X-Forwarded-User"] = []string{session.Email} - req.Header.Del("X-Forwarded-Email") - } else { - req.Header["X-Forwarded-User"] = []string{session.User} - if session.Email != "" { - req.Header["X-Forwarded-Email"] = []string{session.Email} - } else { - req.Header.Del("X-Forwarded-Email") - } - } - - if session.PreferredUsername != "" { - req.Header["X-Forwarded-Preferred-Username"] = []string{session.PreferredUsername} - } else { - req.Header.Del("X-Forwarded-Preferred-Username") - } - - if len(session.Groups) > 0 { - for _, group := range session.Groups { - req.Header.Add("X-Forwarded-Groups", group) - } - } else { - req.Header.Del("X-Forwarded-Groups") - } - } - - if p.SetXAuthRequest { - rw.Header().Set("X-Auth-Request-User", session.User) - if session.Email != "" { - rw.Header().Set("X-Auth-Request-Email", session.Email) - } else { - rw.Header().Del("X-Auth-Request-Email") - } - if session.PreferredUsername != "" { - rw.Header().Set("X-Auth-Request-Preferred-Username", session.PreferredUsername) - } else { - rw.Header().Del("X-Auth-Request-Preferred-Username") - } - - if p.PassAccessToken { - if session.AccessToken != "" { - rw.Header().Set("X-Auth-Request-Access-Token", session.AccessToken) - } else { - rw.Header().Del("X-Auth-Request-Access-Token") - } - } - - if len(session.Groups) > 0 { - for _, group := range session.Groups { - rw.Header().Add("X-Auth-Request-Groups", group) - } - } else { - rw.Header().Del("X-Auth-Request-Groups") - } - } - - if p.PassAccessToken { - if session.AccessToken != "" { - req.Header["X-Forwarded-Access-Token"] = []string{session.AccessToken} - } else { - req.Header.Del("X-Forwarded-Access-Token") - } - } - - if p.PassAuthorization { - if session.IDToken != "" { - req.Header["Authorization"] = []string{fmt.Sprintf("Bearer %s", session.IDToken)} - } else { - req.Header.Del("Authorization") - } - } - if p.SetBasicAuth { - switch { - case p.PreferEmailToUser && session.Email != "": - authVal := b64.StdEncoding.EncodeToString([]byte(session.Email + ":" + p.BasicAuthPassword)) - rw.Header().Set("Authorization", "Basic "+authVal) - case session.User != "": - authVal := b64.StdEncoding.EncodeToString([]byte(session.User + ":" + p.BasicAuthPassword)) - rw.Header().Set("Authorization", "Basic "+authVal) - default: - rw.Header().Del("Authorization") - } - } - if p.SetAuthorization { - if session.IDToken != "" { - rw.Header().Set("Authorization", fmt.Sprintf("Bearer %s", session.IDToken)) - } else { - rw.Header().Del("Authorization") - } - } - if session.Email == "" { rw.Header().Set("GAP-Auth", session.User) } else { @@ -1082,32 +1016,6 @@ func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, req *http.Req } } -// stripAuthHeaders removes Auth headers for allowlisted routes from skipAuthRegex -func (p *OAuthProxy) stripAuthHeaders(req *http.Request) { - if p.PassBasicAuth { - req.Header.Del("X-Forwarded-User") - req.Header.Del("X-Forwarded-Groups") - req.Header.Del("X-Forwarded-Email") - req.Header.Del("X-Forwarded-Preferred-Username") - req.Header.Del("Authorization") - } - - if p.PassUserHeaders { - req.Header.Del("X-Forwarded-User") - req.Header.Del("X-Forwarded-Groups") - req.Header.Del("X-Forwarded-Email") - req.Header.Del("X-Forwarded-Preferred-Username") - } - - if p.PassAccessToken { - req.Header.Del("X-Forwarded-Access-Token") - } - - if p.PassAuthorization { - req.Header.Del("Authorization") - } -} - // isAjax checks if a request is an ajax request func isAjax(req *http.Request) bool { acceptValues := req.Header.Values("Accept") diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 46ee24b8..1736b39f 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -495,6 +495,8 @@ func TestBasicAuthPassword(t *testing.T) { t.Fatal(err) } })) + + basicAuthPassword := "This is a secure password" opts := baseTestOptions() opts.UpstreamServers = options.Upstreams{ { @@ -505,11 +507,22 @@ func TestBasicAuthPassword(t *testing.T) { } opts.Cookie.Secure = false - opts.PassBasicAuth = true - opts.SetBasicAuth = true - opts.PassUserHeaders = true - opts.PreferEmailToUser = true - opts.BasicAuthPassword = "This is a secure password" + opts.InjectRequestHeaders = []options.Header{ + { + Name: "Authorization", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "email", + BasicAuthPassword: &options.SecretSource{ + Value: []byte(base64.StdEncoding.EncodeToString([]byte(basicAuthPassword))), + }, + }, + }, + }, + }, + } + err := validation.Validate(opts) assert.NoError(t, err) @@ -524,148 +537,44 @@ func TestBasicAuthPassword(t *testing.T) { t.Fatal(err) } + // Save the required session rw := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:", strings.NewReader("")) - req.AddCookie(proxy.MakeCSRFCookie(req, "nonce", proxy.CookieExpire, time.Now())) - proxy.ServeHTTP(rw, req) - if rw.Code >= 400 { - t.Fatalf("expected 3xx got %d", rw.Code) - } - cookie := rw.Header().Values("Set-Cookie")[1] - - cookieName := proxy.CookieName - var value string - keyPrefix := cookieName + "=" - - for _, field := range strings.Split(cookie, "; ") { - value = strings.TrimPrefix(field, keyPrefix) - if value != field { - break - } else { - value = "" - } - } - - req, _ = http.NewRequest("GET", "/", strings.NewReader("")) - req.AddCookie(&http.Cookie{ - Name: cookieName, - Value: value, - Path: "/", - Expires: time.Now().Add(time.Duration(24)), - HttpOnly: true, + req, _ := http.NewRequest("GET", "/", nil) + err = proxy.sessionStore.Save(rw, req, &sessions.SessionState{ + Email: emailAddress, }) - req.AddCookie(proxy.MakeCSRFCookie(req, "nonce", proxy.CookieExpire, time.Now())) + assert.NoError(t, err) + // Extract the cookie value to inject into the test request + cookie := rw.Header().Values("Set-Cookie")[0] + + req, _ = http.NewRequest("GET", "/", nil) + req.Header.Set("Cookie", cookie) rw = httptest.NewRecorder() proxy.ServeHTTP(rw, req) // The username in the basic auth credentials is expected to be equal to the email address from the // auth response, so we use the same variable here. - expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(emailAddress+":"+opts.BasicAuthPassword)) + expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(emailAddress+":"+basicAuthPassword)) assert.Equal(t, expectedHeader, rw.Body.String()) providerServer.Close() } -func TestBasicAuthWithEmail(t *testing.T) { - opts := baseTestOptions() - opts.PassBasicAuth = true - opts.PassUserHeaders = false - opts.PreferEmailToUser = false - opts.BasicAuthPassword = "This is a secure password" - err := validation.Validate(opts) - assert.NoError(t, err) - - const emailAddress = "john.doe@example.com" - const userName = "9fcab5c9b889a557" - - // The username in the basic auth credentials is expected to be equal to the email address from the - expectedEmailHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(emailAddress+":"+opts.BasicAuthPassword)) - expectedUserHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(userName+":"+opts.BasicAuthPassword)) - - created := time.Now() - session := &sessions.SessionState{ - User: userName, - Email: emailAddress, - AccessToken: "oauth_token", - CreatedAt: &created, - } - { - rw := httptest.NewRecorder() - req, _ := http.NewRequest("GET", opts.ProxyPrefix+"/testCase0", nil) - proxy, err := NewOAuthProxy(opts, func(email string) bool { - return email == emailAddress - }) - if err != nil { - t.Fatal(err) - } - proxy.addHeadersForProxying(rw, req, session) - assert.Equal(t, expectedUserHeader, req.Header["Authorization"][0]) - assert.Equal(t, userName, req.Header["X-Forwarded-User"][0]) - } - - opts.PreferEmailToUser = true - { - rw := httptest.NewRecorder() - req, _ := http.NewRequest("GET", opts.ProxyPrefix+"/testCase1", nil) - - proxy, err := NewOAuthProxy(opts, func(email string) bool { - return email == emailAddress - }) - if err != nil { - t.Fatal(err) - } - proxy.addHeadersForProxying(rw, req, session) - assert.Equal(t, expectedEmailHeader, req.Header["Authorization"][0]) - assert.Equal(t, emailAddress, req.Header["X-Forwarded-User"][0]) - } -} - -func TestPassUserHeadersWithEmail(t *testing.T) { - opts := baseTestOptions() - err := validation.Validate(opts) - assert.NoError(t, err) - - const emailAddress = "john.doe@example.com" - const userName = "9fcab5c9b889a557" - - created := time.Now() - session := &sessions.SessionState{ - User: userName, - Email: emailAddress, - AccessToken: "oauth_token", - CreatedAt: &created, - } - { - rw := httptest.NewRecorder() - req, _ := http.NewRequest("GET", opts.ProxyPrefix+"/testCase0", nil) - proxy, err := NewOAuthProxy(opts, func(email string) bool { - return email == emailAddress - }) - if err != nil { - t.Fatal(err) - } - proxy.addHeadersForProxying(rw, req, session) - assert.Equal(t, userName, req.Header["X-Forwarded-User"][0]) - } - - opts.PreferEmailToUser = true - { - rw := httptest.NewRecorder() - req, _ := http.NewRequest("GET", opts.ProxyPrefix+"/testCase1", nil) - - proxy, err := NewOAuthProxy(opts, func(email string) bool { - return email == emailAddress - }) - if err != nil { - t.Fatal(err) - } - proxy.addHeadersForProxying(rw, req, session) - assert.Equal(t, emailAddress, req.Header["X-Forwarded-User"][0]) - } -} - func TestPassGroupsHeadersWithGroups(t *testing.T) { opts := baseTestOptions() + opts.InjectRequestHeaders = []options.Header{ + { + Name: "X-Forwarded-Groups", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "groups", + }, + }, + }, + }, + } + err := validation.Validate(opts) assert.NoError(t, err) @@ -681,161 +590,27 @@ func TestPassGroupsHeadersWithGroups(t *testing.T) { AccessToken: "oauth_token", CreatedAt: &created, } - { - rw := httptest.NewRecorder() - req, _ := http.NewRequest("GET", opts.ProxyPrefix+"/testCase0", nil) - proxy, err := NewOAuthProxy(opts, func(email string) bool { - return email == emailAddress - }) - if err != nil { - t.Fatal(err) - } - proxy.addHeadersForProxying(rw, req, session) - assert.Equal(t, groups, req.Header["X-Forwarded-Groups"]) - } -} -func TestStripAuthHeaders(t *testing.T) { - testCases := map[string]struct { - SkipAuthStripHeaders bool - PassBasicAuth bool - PassUserHeaders bool - PassAccessToken bool - PassAuthorization bool - StrippedHeaders map[string]bool - }{ - "Default options": { - SkipAuthStripHeaders: true, - PassBasicAuth: true, - PassUserHeaders: true, - PassAccessToken: false, - PassAuthorization: false, - StrippedHeaders: map[string]bool{ - "X-Forwarded-User": true, - "X-Forwared-Groups": true, - "X-Forwarded-Email": true, - "X-Forwarded-Preferred-Username": true, - "X-Forwarded-Access-Token": false, - "Authorization": true, - }, - }, - "Pass access token": { - SkipAuthStripHeaders: true, - PassBasicAuth: true, - PassUserHeaders: true, - PassAccessToken: true, - PassAuthorization: false, - StrippedHeaders: map[string]bool{ - "X-Forwarded-User": true, - "X-Forwared-Groups": true, - "X-Forwarded-Email": true, - "X-Forwarded-Preferred-Username": true, - "X-Forwarded-Access-Token": true, - "Authorization": true, - }, - }, - "Nothing setting Authorization": { - SkipAuthStripHeaders: true, - PassBasicAuth: false, - PassUserHeaders: true, - PassAccessToken: true, - PassAuthorization: false, - StrippedHeaders: map[string]bool{ - "X-Forwarded-User": true, - "X-Forwared-Groups": true, - "X-Forwarded-Email": true, - "X-Forwarded-Preferred-Username": true, - "X-Forwarded-Access-Token": true, - "Authorization": false, - }, - }, - "Only Authorization header modified": { - SkipAuthStripHeaders: true, - PassBasicAuth: false, - PassUserHeaders: false, - PassAccessToken: false, - PassAuthorization: true, - StrippedHeaders: map[string]bool{ - "X-Forwarded-User": false, - "X-Forwared-Groups": false, - "X-Forwarded-Email": false, - "X-Forwarded-Preferred-Username": false, - "X-Forwarded-Access-Token": false, - "Authorization": true, - }, - }, - "Don't strip any headers (default options)": { - SkipAuthStripHeaders: false, - PassBasicAuth: true, - PassUserHeaders: true, - PassAccessToken: false, - PassAuthorization: false, - StrippedHeaders: map[string]bool{ - "X-Forwarded-User": false, - "X-Forwared-Groups": false, - "X-Forwarded-Email": false, - "X-Forwarded-Preferred-Username": false, - "X-Forwarded-Access-Token": false, - "Authorization": false, - }, - }, - "Don't strip any headers (custom options)": { - SkipAuthStripHeaders: false, - PassBasicAuth: true, - PassUserHeaders: true, - PassAccessToken: true, - PassAuthorization: false, - StrippedHeaders: map[string]bool{ - "X-Forwarded-User": false, - "X-Forwared-Groups": false, - "X-Forwarded-Email": false, - "X-Forwarded-Preferred-Username": false, - "X-Forwarded-Access-Token": false, - "Authorization": false, - }, - }, - } + proxy, err := NewOAuthProxy(opts, func(email string) bool { + return email == emailAddress + }) + assert.NoError(t, err) - initialHeaders := map[string]string{ - "X-Forwarded-User": "9fcab5c9b889a557", - "X-Forwarded-Email": "john.doe@example.com", - "X-Forwarded-Groups": "a,b,c", - "X-Forwarded-Preferred-Username": "john.doe", - "X-Forwarded-Access-Token": "AccessToken", - "Authorization": "bearer IDToken", - } + // Save the required session + rw := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/", nil) + err = proxy.sessionStore.Save(rw, req, session) + assert.NoError(t, err) - for name, tc := range testCases { - t.Run(name, func(t *testing.T) { - opts := baseTestOptions() - opts.SkipAuthStripHeaders = tc.SkipAuthStripHeaders - opts.PassBasicAuth = tc.PassBasicAuth - opts.PassUserHeaders = tc.PassUserHeaders - opts.PassAccessToken = tc.PassAccessToken - opts.PassAuthorization = tc.PassAuthorization - err := validation.Validate(opts) - assert.NoError(t, err) + // Extract the cookie value to inject into the test request + cookie := rw.Header().Values("Set-Cookie")[0] - req, _ := http.NewRequest("GET", fmt.Sprintf("%s/testCase", opts.ProxyPrefix), nil) - for header, val := range initialHeaders { - req.Header.Set(header, val) - } + req, _ = http.NewRequest("GET", "/", nil) + req.Header.Set("Cookie", cookie) + rw = httptest.NewRecorder() + proxy.ServeHTTP(rw, req) - proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) - assert.NoError(t, err) - if proxy.skipAuthStripHeaders { - proxy.stripAuthHeaders(req) - } - - for header, stripped := range tc.StrippedHeaders { - if stripped { - assert.Equal(t, req.Header.Get(header), "") - } else { - assert.Equal(t, req.Header.Get(header), initialHeaders[header]) - } - } - }) - } + assert.Equal(t, groups, req.Header["X-Forwarded-Groups"]) } type PassAccessTokenTest struct { @@ -884,7 +659,21 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) (*PassAccessTokenTe } patt.opts.Cookie.Secure = false - patt.opts.PassAccessToken = opts.PassAccessToken + if opts.PassAccessToken { + patt.opts.InjectRequestHeaders = []options.Header{ + { + Name: "X-Forwarded-Access-Token", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "access_token", + }, + }, + }, + }, + } + } + err := validation.Validate(patt.opts) if err != nil { return nil, err @@ -1442,7 +1231,48 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { var pcTest ProcessCookieTest pcTest.opts = baseTestOptions() - pcTest.opts.SetXAuthRequest = true + pcTest.opts.InjectResponseHeaders = []options.Header{ + { + Name: "X-Auth-Request-User", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "user", + }, + }, + }, + }, + { + Name: "X-Auth-Request-Email", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "email", + }, + }, + }, + }, + { + Name: "X-Auth-Request-Groups", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "groups", + }, + }, + }, + }, + { + Name: "X-Forwarded-Preferred-Username", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "preferred_username", + }, + }, + }, + }, + } pcTest.opts.AllowedGroups = []string{"oauth_groups"} err := validation.Validate(pcTest.opts) assert.NoError(t, err) @@ -1480,8 +1310,62 @@ func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) { var pcTest ProcessCookieTest pcTest.opts = baseTestOptions() - pcTest.opts.SetXAuthRequest = true - pcTest.opts.SetBasicAuth = true + pcTest.opts.InjectResponseHeaders = []options.Header{ + { + Name: "X-Auth-Request-User", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "user", + }, + }, + }, + }, + { + Name: "X-Auth-Request-Email", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "email", + }, + }, + }, + }, + { + Name: "X-Auth-Request-Groups", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "groups", + }, + }, + }, + }, + { + Name: "X-Forwarded-Preferred-Username", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "preferred_username", + }, + }, + }, + }, + { + Name: "Authorization", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "user", + BasicAuthPassword: &options.SecretSource{ + Value: []byte(base64.StdEncoding.EncodeToString([]byte("This is a secure password"))), + }, + }, + }, + }, + }, + } + err := validation.Validate(pcTest.opts) assert.NoError(t, err) @@ -1511,7 +1395,7 @@ func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) { assert.Equal(t, http.StatusAccepted, pcTest.rw.Code) assert.Equal(t, "oauth_user", pcTest.rw.Header().Values("X-Auth-Request-User")[0]) assert.Equal(t, "oauth_user@example.com", pcTest.rw.Header().Values("X-Auth-Request-Email")[0]) - expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte("oauth_user:"+pcTest.opts.BasicAuthPassword)) + expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte("oauth_user:This is a secure password")) assert.Equal(t, expectedHeader, pcTest.rw.Header().Values("Authorization")[0]) } @@ -1519,8 +1403,48 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) { var pcTest ProcessCookieTest pcTest.opts = baseTestOptions() - pcTest.opts.SetXAuthRequest = true - pcTest.opts.SetBasicAuth = false + pcTest.opts.InjectResponseHeaders = []options.Header{ + { + Name: "X-Auth-Request-User", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "user", + }, + }, + }, + }, + { + Name: "X-Auth-Request-Email", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "email", + }, + }, + }, + }, + { + Name: "X-Auth-Request-Groups", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "groups", + }, + }, + }, + }, + { + Name: "X-Forwarded-Preferred-Username", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "preferred_username", + }, + }, + }, + }, + } err := validation.Validate(pcTest.opts) assert.NoError(t, err) @@ -1985,9 +1909,74 @@ func TestGetJwtSession(t *testing.T) { &oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true}) test, err := NewAuthOnlyEndpointTest(func(opts *options.Options) { - opts.PassAuthorization = true - opts.SetAuthorization = true - opts.SetXAuthRequest = true + opts.InjectRequestHeaders = []options.Header{ + { + Name: "Authorization", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "id_token", + Prefix: "Bearer ", + }, + }, + }, + }, + { + Name: "X-Forwarded-User", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "user", + }, + }, + }, + }, + { + Name: "X-Forwarded-Email", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "email", + }, + }, + }, + }, + } + + opts.InjectResponseHeaders = []options.Header{ + { + Name: "Authorization", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "id_token", + Prefix: "Bearer ", + }, + }, + }, + }, + { + Name: "X-Auth-Request-User", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "user", + }, + }, + }, + }, + { + Name: "X-Auth-Request-Email", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "email", + }, + }, + }, + }, + } + opts.SkipJwtBearerTokens = true opts.SetJWTBearerVerifiers(append(opts.GetJWTBearerVerifiers(), verifier)) }) @@ -2004,15 +1993,6 @@ func TestGetJwtSession(t *testing.T) { "Authorization": {authHeader}, } - // Bearer - expires := time.Unix(1912151821, 0) - session, err := test.proxy.getAuthenticatedSession(test.rw, test.req) - assert.NoError(t, err) - assert.Equal(t, session.User, "1234567890") - assert.Equal(t, session.Email, "john@example.com") - assert.Equal(t, session.ExpiresOn, &expires) - assert.Equal(t, session.IDToken, goodJwt) - test.proxy.ServeHTTP(test.rw, test.req) if test.rw.Code >= 400 { t.Fatalf("expected 3xx got %d", test.rw.Code) @@ -2140,6 +2120,43 @@ func baseTestOptions() *options.Options { opts.ClientID = clientID opts.ClientSecret = clientSecret opts.EmailDomains = []string{"*"} + + // Default injected headers for legacy configuration + opts.InjectRequestHeaders = []options.Header{ + { + Name: "Authorization", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "user", + BasicAuthPassword: &options.SecretSource{ + Value: []byte(base64.StdEncoding.EncodeToString([]byte("This is a secure password"))), + }, + }, + }, + }, + }, + { + Name: "X-Forwarded-User", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "user", + }, + }, + }, + }, + { + Name: "X-Forwarded-Email", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "email", + }, + }, + }, + }, + } return opts } diff --git a/pkg/validation/options.go b/pkg/validation/options.go index 29d8f71e..16a19612 100644 --- a/pkg/validation/options.go +++ b/pkg/validation/options.go @@ -73,10 +73,6 @@ func Validate(o *options.Options) error { "\n use email-domain=* to authorize all email addresses") } - if o.SetBasicAuth && o.SetAuthorization { - msgs = append(msgs, "mutually exclusive: set-basic-auth and set-authorization-header can not both be true") - } - if o.OIDCIssuerURL != "" { ctx := context.Background() @@ -161,10 +157,6 @@ func Validate(o *options.Options) error { } } - if o.PreferEmailToUser && !o.PassBasicAuth && !o.PassUserHeaders { - msgs = append(msgs, "PreferEmailToUser should only be used with PassBasicAuth or PassUserHeaders") - } - if o.SkipJwtBearerTokens { // Configure extra issuers if len(o.ExtraJwtIssuers) > 0 { diff --git a/pkg/validation/options_test.go b/pkg/validation/options_test.go index f88ef7af..e19c6991 100644 --- a/pkg/validation/options_test.go +++ b/pkg/validation/options_test.go @@ -162,29 +162,6 @@ func TestDefaultProviderApiSettings(t *testing.T) { assert.Equal(t, "profile email", p.Scope) } -func TestPassAccessTokenRequiresSpecificCookieSecretLengths(t *testing.T) { - o := testOptions() - assert.Equal(t, nil, Validate(o)) - - assert.Equal(t, false, o.PassAccessToken) - o.PassAccessToken = true - o.Cookie.Secret = "cookie of invalid length-" - assert.NotEqual(t, nil, Validate(o)) - - o.PassAccessToken = false - o.Cookie.Refresh = time.Duration(24) * time.Hour - assert.NotEqual(t, nil, Validate(o)) - - o.Cookie.Secret = "16 bytes AES-128" - assert.Equal(t, nil, Validate(o)) - - o.Cookie.Secret = "24 byte secret AES-192--" - assert.Equal(t, nil, Validate(o)) - - o.Cookie.Secret = "32 byte secret for AES-256------" - assert.Equal(t, nil, Validate(o)) -} - func TestCookieRefreshMustBeLessThanCookieExpire(t *testing.T) { o := testOptions() assert.Equal(t, nil, Validate(o)) diff --git a/pkg/validation/sessions.go b/pkg/validation/sessions.go index 8cacd48b..48d4042a 100644 --- a/pkg/validation/sessions.go +++ b/pkg/validation/sessions.go @@ -16,18 +16,21 @@ func validateSessionCookieMinimal(o *options.Options) []string { } msgs := []string{} - if o.PassAuthorization { - msgs = append(msgs, - "pass_authorization_header requires oauth tokens in sessions. session_cookie_minimal cannot be set") - } - if o.SetAuthorization { - msgs = append(msgs, - "set_authorization_header requires oauth tokens in sessions. session_cookie_minimal cannot be set") - } - if o.PassAccessToken { - msgs = append(msgs, - "pass_access_token requires oauth tokens in sessions. session_cookie_minimal cannot be set") + for _, header := range append(o.InjectRequestHeaders, o.InjectResponseHeaders...) { + for _, value := range header.Values { + if value.ClaimSource != nil { + if value.ClaimSource.Claim == "access_token" { + msgs = append(msgs, + fmt.Sprintf("access_token claim for header %q requires oauth tokens in sessions. session_cookie_minimal cannot be set", header.Name)) + } + if value.ClaimSource.Claim == "id_token" { + msgs = append(msgs, + fmt.Sprintf("id_token claim for header %q requires oauth tokens in sessions. session_cookie_minimal cannot be set", header.Name)) + } + } + } } + if o.Cookie.Refresh != time.Duration(0) { msgs = append(msgs, "cookie_refresh > 0 requires oauth tokens in sessions. session_cookie_minimal cannot be set") diff --git a/pkg/validation/sessions_test.go b/pkg/validation/sessions_test.go index f6463431..f1943288 100644 --- a/pkg/validation/sessions_test.go +++ b/pkg/validation/sessions_test.go @@ -13,10 +13,9 @@ import ( var _ = Describe("Sessions", func() { const ( - passAuthorizationMsg = "pass_authorization_header requires oauth tokens in sessions. session_cookie_minimal cannot be set" - setAuthorizationMsg = "set_authorization_header requires oauth tokens in sessions. session_cookie_minimal cannot be set" - passAccessTokenMsg = "pass_access_token requires oauth tokens in sessions. session_cookie_minimal cannot be set" - cookieRefreshMsg = "cookie_refresh > 0 requires oauth tokens in sessions. session_cookie_minimal cannot be set" + idTokenConflictMsg = "id_token claim for header \"X-ID-Token\" requires oauth tokens in sessions. session_cookie_minimal cannot be set" + accessTokenConflictMsg = "access_token claim for header \"X-Access-Token\" requires oauth tokens in sessions. session_cookie_minimal cannot be set" + cookieRefreshMsg = "cookie_refresh > 0 requires oauth tokens in sessions. session_cookie_minimal cannot be set" ) type cookieMinimalTableInput struct { @@ -38,14 +37,25 @@ var _ = Describe("Sessions", func() { }, errStrings: []string{}, }), - Entry("No minimal cookie session & passAuthorization", &cookieMinimalTableInput{ + Entry("No minimal cookie session & request header has access_token claim", &cookieMinimalTableInput{ opts: &options.Options{ Session: options.SessionOptions{ Cookie: options.CookieStoreOptions{ Minimal: false, }, }, - PassAuthorization: true, + InjectRequestHeaders: []options.Header{ + { + Name: "X-Access-Token", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "access_token", + }, + }, + }, + }, + }, }, errStrings: []string{}, }), @@ -59,38 +69,71 @@ var _ = Describe("Sessions", func() { }, errStrings: []string{}, }), - Entry("PassAuthorization conflict", &cookieMinimalTableInput{ + Entry("Request Header id_token conflict", &cookieMinimalTableInput{ opts: &options.Options{ Session: options.SessionOptions{ Cookie: options.CookieStoreOptions{ Minimal: true, }, }, - PassAuthorization: true, + InjectRequestHeaders: []options.Header{ + { + Name: "X-ID-Token", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "id_token", + }, + }, + }, + }, + }, }, - errStrings: []string{passAuthorizationMsg}, + errStrings: []string{idTokenConflictMsg}, }), - Entry("SetAuthorization conflict", &cookieMinimalTableInput{ + Entry("Response Header id_token conflict", &cookieMinimalTableInput{ opts: &options.Options{ Session: options.SessionOptions{ Cookie: options.CookieStoreOptions{ Minimal: true, }, }, - SetAuthorization: true, + InjectResponseHeaders: []options.Header{ + { + Name: "X-ID-Token", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "id_token", + }, + }, + }, + }, + }, }, - errStrings: []string{setAuthorizationMsg}, + errStrings: []string{idTokenConflictMsg}, }), - Entry("PassAccessToken conflict", &cookieMinimalTableInput{ + Entry("Request Header access_token conflict", &cookieMinimalTableInput{ opts: &options.Options{ Session: options.SessionOptions{ Cookie: options.CookieStoreOptions{ Minimal: true, }, }, - PassAccessToken: true, + InjectRequestHeaders: []options.Header{ + { + Name: "X-Access-Token", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "access_token", + }, + }, + }, + }, + }, }, - errStrings: []string{passAccessTokenMsg}, + errStrings: []string{accessTokenConflictMsg}, }), Entry("CookieRefresh conflict", &cookieMinimalTableInput{ opts: &options.Options{ @@ -112,10 +155,32 @@ var _ = Describe("Sessions", func() { Minimal: true, }, }, - PassAuthorization: true, - PassAccessToken: true, + InjectResponseHeaders: []options.Header{ + { + Name: "X-ID-Token", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "id_token", + }, + }, + }, + }, + }, + InjectRequestHeaders: []options.Header{ + { + Name: "X-Access-Token", + Values: []options.HeaderValue{ + { + ClaimSource: &options.ClaimSource{ + Claim: "access_token", + }, + }, + }, + }, + }, }, - errStrings: []string{passAuthorizationMsg, passAccessTokenMsg}, + errStrings: []string{idTokenConflictMsg, accessTokenConflictMsg}, }), )