Add header injector middlewares
This commit is contained in:
		
							parent
							
								
									ca5fd3c39f
								
							
						
					
					
						commit
						0a26a80337
					
				|  | @ -0,0 +1,102 @@ | |||
| package middleware | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 
 | ||||
| 	"github.com/justinas/alice" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/header" | ||||
| ) | ||||
| 
 | ||||
| func NewRequestHeaderInjector(headers []options.Header) (alice.Constructor, error) { | ||||
| 	headerInjector, err := newRequestHeaderInjector(headers) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error building request header injector: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	strip := newStripHeaders(headers) | ||||
| 	if strip != nil { | ||||
| 		return alice.New(strip, headerInjector).Then, nil | ||||
| 	} | ||||
| 	return headerInjector, nil | ||||
| } | ||||
| 
 | ||||
| func newStripHeaders(headers []options.Header) alice.Constructor { | ||||
| 	headersToStrip := []string{} | ||||
| 	for _, header := range headers { | ||||
| 		if !header.PreserveRequestValue { | ||||
| 			headersToStrip = append(headersToStrip, header.Name) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if len(headersToStrip) == 0 { | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	return func(next http.Handler) http.Handler { | ||||
| 		return stripHeaders(headersToStrip, next) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func stripHeaders(headers []string, next http.Handler) http.Handler { | ||||
| 	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||
| 		for _, header := range headers { | ||||
| 			req.Header.Del(header) | ||||
| 		} | ||||
| 		next.ServeHTTP(rw, req) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func newRequestHeaderInjector(headers []options.Header) (alice.Constructor, error) { | ||||
| 	injector, err := header.NewInjector(headers) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error building request injector: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	return func(next http.Handler) http.Handler { | ||||
| 		return injectRequestHeaders(injector, next) | ||||
| 	}, nil | ||||
| } | ||||
| 
 | ||||
| func injectRequestHeaders(injector header.Injector, next http.Handler) http.Handler { | ||||
| 	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||
| 		scope := GetRequestScope(req) | ||||
| 
 | ||||
| 		// If scope is nil, this will panic.
 | ||||
| 		// A scope should always be injected before this handler is called.
 | ||||
| 		injector.Inject(req.Header, scope.Session) | ||||
| 		next.ServeHTTP(rw, req) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func NewResponseHeaderInjector(headers []options.Header) (alice.Constructor, error) { | ||||
| 	headerInjector, err := newResponseHeaderInjector(headers) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error building response header injector: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	return headerInjector, nil | ||||
| } | ||||
| 
 | ||||
| func newResponseHeaderInjector(headers []options.Header) (alice.Constructor, error) { | ||||
| 	injector, err := header.NewInjector(headers) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error building response injector: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	return func(next http.Handler) http.Handler { | ||||
| 		return injectResponseHeaders(injector, next) | ||||
| 	}, nil | ||||
| } | ||||
| 
 | ||||
| func injectResponseHeaders(injector header.Injector, next http.Handler) http.Handler { | ||||
| 	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||
| 		scope := GetRequestScope(req) | ||||
| 
 | ||||
| 		// If scope is nil, this will panic.
 | ||||
| 		// A scope should always be injected before this handler is called.
 | ||||
| 		injector.Inject(rw.Header(), scope.Session) | ||||
| 		next.ServeHTTP(rw, req) | ||||
| 	}) | ||||
| } | ||||
|  | @ -0,0 +1,405 @@ | |||
| package middleware | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"encoding/base64" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 
 | ||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||
| 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/ginkgo/extensions/table" | ||||
| 	. "github.com/onsi/gomega" | ||||
| ) | ||||
| 
 | ||||
| var _ = Describe("Headers Suite", func() { | ||||
| 	type headersTableInput struct { | ||||
| 		headers         []options.Header | ||||
| 		initialHeaders  http.Header | ||||
| 		session         *sessionsapi.SessionState | ||||
| 		expectedHeaders http.Header | ||||
| 		expectedErr     string | ||||
| 	} | ||||
| 
 | ||||
| 	DescribeTable("the request header injector", | ||||
| 		func(in headersTableInput) { | ||||
| 			scope := &middlewareapi.RequestScope{ | ||||
| 				Session: in.session, | ||||
| 			} | ||||
| 
 | ||||
| 			// Set up the request with a request scope
 | ||||
| 			req := httptest.NewRequest("", "/", nil) | ||||
| 			contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) | ||||
| 			req = req.WithContext(contextWithScope) | ||||
| 			req.Header = in.initialHeaders.Clone() | ||||
| 
 | ||||
| 			rw := httptest.NewRecorder() | ||||
| 
 | ||||
| 			// Create the handler with a next handler that will capture the headers
 | ||||
| 			// from the request
 | ||||
| 			var gotHeaders http.Header | ||||
| 			injector, err := NewRequestHeaderInjector(in.headers) | ||||
| 			if in.expectedErr != "" { | ||||
| 				Expect(err).To(MatchError(in.expectedErr)) | ||||
| 				return | ||||
| 			} | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 			handler := injector(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 				gotHeaders = r.Header.Clone() | ||||
| 			})) | ||||
| 			handler.ServeHTTP(rw, req) | ||||
| 
 | ||||
| 			Expect(gotHeaders).To(Equal(in.expectedHeaders)) | ||||
| 		}, | ||||
| 		Entry("with no configured headers", headersTableInput{ | ||||
| 			headers: []options.Header{}, | ||||
| 			initialHeaders: http.Header{ | ||||
| 				"foo": []string{"bar", "baz"}, | ||||
| 			}, | ||||
| 			session: &sessionsapi.SessionState{}, | ||||
| 			expectedHeaders: http.Header{ | ||||
| 				"foo": []string{"bar", "baz"}, | ||||
| 			}, | ||||
| 			expectedErr: "", | ||||
| 		}), | ||||
| 		Entry("with a claim valued header", headersTableInput{ | ||||
| 			headers: []options.Header{ | ||||
| 				{ | ||||
| 					Name: "Claim", | ||||
| 					Values: []options.HeaderValue{ | ||||
| 						{ | ||||
| 							ClaimSource: &options.ClaimSource{ | ||||
| 								Claim: "id_token", | ||||
| 							}, | ||||
| 						}, | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			initialHeaders: http.Header{ | ||||
| 				"foo": []string{"bar", "baz"}, | ||||
| 			}, | ||||
| 			session: &sessionsapi.SessionState{ | ||||
| 				IDToken: "IDToken-1234", | ||||
| 			}, | ||||
| 			expectedHeaders: http.Header{ | ||||
| 				"foo":   []string{"bar", "baz"}, | ||||
| 				"Claim": []string{"IDToken-1234"}, | ||||
| 			}, | ||||
| 			expectedErr: "", | ||||
| 		}), | ||||
| 		Entry("with a claim valued header (without preservation)", headersTableInput{ | ||||
| 			headers: []options.Header{ | ||||
| 				{ | ||||
| 					Name: "Claim", | ||||
| 					Values: []options.HeaderValue{ | ||||
| 						{ | ||||
| 							ClaimSource: &options.ClaimSource{ | ||||
| 								Claim: "id_token", | ||||
| 							}, | ||||
| 						}, | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			initialHeaders: http.Header{ | ||||
| 				"Claim": []string{"bar", "baz"}, | ||||
| 			}, | ||||
| 			session: &sessionsapi.SessionState{ | ||||
| 				IDToken: "IDToken-1234", | ||||
| 			}, | ||||
| 			expectedHeaders: http.Header{ | ||||
| 				"Claim": []string{"IDToken-1234"}, | ||||
| 			}, | ||||
| 			expectedErr: "", | ||||
| 		}), | ||||
| 		Entry("with a claim valued header (with preservation)", headersTableInput{ | ||||
| 			headers: []options.Header{ | ||||
| 				{ | ||||
| 					Name:                 "Claim", | ||||
| 					PreserveRequestValue: true, | ||||
| 					Values: []options.HeaderValue{ | ||||
| 						{ | ||||
| 							ClaimSource: &options.ClaimSource{ | ||||
| 								Claim: "id_token", | ||||
| 							}, | ||||
| 						}, | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			initialHeaders: http.Header{ | ||||
| 				"Claim": []string{"bar", "baz"}, | ||||
| 			}, | ||||
| 			session: &sessionsapi.SessionState{ | ||||
| 				IDToken: "IDToken-1234", | ||||
| 			}, | ||||
| 			expectedHeaders: http.Header{ | ||||
| 				"Claim": []string{"bar", "baz", "IDToken-1234"}, | ||||
| 			}, | ||||
| 			expectedErr: "", | ||||
| 		}), | ||||
| 		Entry("with a claim valued header that's not present (without preservation)", headersTableInput{ | ||||
| 			headers: []options.Header{ | ||||
| 				{ | ||||
| 					Name: "Claim", | ||||
| 					Values: []options.HeaderValue{ | ||||
| 						{ | ||||
| 							ClaimSource: &options.ClaimSource{ | ||||
| 								Claim: "id_token", | ||||
| 							}, | ||||
| 						}, | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			initialHeaders: http.Header{ | ||||
| 				"Claim": []string{"bar", "baz"}, | ||||
| 			}, | ||||
| 			session:         nil, | ||||
| 			expectedHeaders: http.Header{}, | ||||
| 			expectedErr:     "", | ||||
| 		}), | ||||
| 		Entry("with a claim valued header that's not present (with preservation)", headersTableInput{ | ||||
| 			headers: []options.Header{ | ||||
| 				{ | ||||
| 					Name:                 "Claim", | ||||
| 					PreserveRequestValue: true, | ||||
| 					Values: []options.HeaderValue{ | ||||
| 						{ | ||||
| 							ClaimSource: &options.ClaimSource{ | ||||
| 								Claim: "id_token", | ||||
| 							}, | ||||
| 						}, | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			initialHeaders: http.Header{ | ||||
| 				"Claim": []string{"bar", "baz"}, | ||||
| 			}, | ||||
| 			session: nil, | ||||
| 			expectedHeaders: http.Header{ | ||||
| 				"Claim": []string{"bar", "baz"}, | ||||
| 			}, | ||||
| 			expectedErr: "", | ||||
| 		}), | ||||
| 		Entry("with an invalid basicAuthPassword claim valued header", headersTableInput{ | ||||
| 			headers: []options.Header{ | ||||
| 				{ | ||||
| 					Name: "X-Auth-Request-Authorization", | ||||
| 					Values: []options.HeaderValue{ | ||||
| 						{ | ||||
| 							ClaimSource: &options.ClaimSource{ | ||||
| 								Claim: "user", | ||||
| 								BasicAuthPassword: &options.SecretSource{ | ||||
| 									Value:   []byte(base64.StdEncoding.EncodeToString([]byte("basic-password"))), | ||||
| 									FromEnv: "SECRET_ENV", | ||||
| 								}, | ||||
| 							}, | ||||
| 						}, | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			initialHeaders: http.Header{ | ||||
| 				"foo": []string{"bar", "baz"}, | ||||
| 			}, | ||||
| 			session: &sessionsapi.SessionState{ | ||||
| 				User: "user-123", | ||||
| 			}, | ||||
| 			expectedHeaders: nil, | ||||
| 			expectedErr:     "error building request header injector: error building request injector: error building injector for header \"X-Auth-Request-Authorization\": error loading basicAuthPassword: secret source is invalid: exactly one entry required, specify either value, fromEnv or fromFile", | ||||
| 		}), | ||||
| 	) | ||||
| 
 | ||||
| 	DescribeTable("the response header injector", | ||||
| 		func(in headersTableInput) { | ||||
| 			scope := &middlewareapi.RequestScope{ | ||||
| 				Session: in.session, | ||||
| 			} | ||||
| 
 | ||||
| 			// Set up the request with a request scope
 | ||||
| 			req := httptest.NewRequest("", "/", nil) | ||||
| 			contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) | ||||
| 			req = req.WithContext(contextWithScope) | ||||
| 
 | ||||
| 			rw := httptest.NewRecorder() | ||||
| 			for key, values := range in.initialHeaders { | ||||
| 				for _, value := range values { | ||||
| 					rw.Header().Add(key, value) | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			// Create the handler with a next handler that will capture the headers
 | ||||
| 			// from the request
 | ||||
| 			var gotHeaders http.Header | ||||
| 			injector, err := NewResponseHeaderInjector(in.headers) | ||||
| 			if in.expectedErr != "" { | ||||
| 				Expect(err).To(MatchError(in.expectedErr)) | ||||
| 				return | ||||
| 			} | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 			handler := injector(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 				gotHeaders = w.Header().Clone() | ||||
| 			})) | ||||
| 			handler.ServeHTTP(rw, req) | ||||
| 
 | ||||
| 			Expect(gotHeaders).To(Equal(in.expectedHeaders)) | ||||
| 		}, | ||||
| 		Entry("with no configured headers", headersTableInput{ | ||||
| 			headers: []options.Header{}, | ||||
| 			initialHeaders: http.Header{ | ||||
| 				"Foo": []string{"bar", "baz"}, | ||||
| 			}, | ||||
| 			session: &sessionsapi.SessionState{}, | ||||
| 			expectedHeaders: http.Header{ | ||||
| 				"Foo": []string{"bar", "baz"}, | ||||
| 			}, | ||||
| 			expectedErr: "", | ||||
| 		}), | ||||
| 		Entry("with a claim valued header", headersTableInput{ | ||||
| 			headers: []options.Header{ | ||||
| 				{ | ||||
| 					Name: "Claim", | ||||
| 					Values: []options.HeaderValue{ | ||||
| 						{ | ||||
| 							ClaimSource: &options.ClaimSource{ | ||||
| 								Claim: "id_token", | ||||
| 							}, | ||||
| 						}, | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			initialHeaders: http.Header{ | ||||
| 				"Foo": []string{"bar", "baz"}, | ||||
| 			}, | ||||
| 			session: &sessionsapi.SessionState{ | ||||
| 				IDToken: "IDToken-1234", | ||||
| 			}, | ||||
| 			expectedHeaders: http.Header{ | ||||
| 				"Foo":   []string{"bar", "baz"}, | ||||
| 				"Claim": []string{"IDToken-1234"}, | ||||
| 			}, | ||||
| 			expectedErr: "", | ||||
| 		}), | ||||
| 		Entry("with a claim valued header (without preservation)", headersTableInput{ | ||||
| 			headers: []options.Header{ | ||||
| 				{ | ||||
| 					Name: "Claim", | ||||
| 					Values: []options.HeaderValue{ | ||||
| 						{ | ||||
| 							ClaimSource: &options.ClaimSource{ | ||||
| 								Claim: "id_token", | ||||
| 							}, | ||||
| 						}, | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			initialHeaders: http.Header{ | ||||
| 				"Claim": []string{"bar", "baz"}, | ||||
| 			}, | ||||
| 			session: &sessionsapi.SessionState{ | ||||
| 				IDToken: "IDToken-1234", | ||||
| 			}, | ||||
| 			expectedHeaders: http.Header{ | ||||
| 				"Claim": []string{"bar", "baz", "IDToken-1234"}, | ||||
| 			}, | ||||
| 			expectedErr: "", | ||||
| 		}), | ||||
| 		Entry("with a claim valued header (with preservation)", headersTableInput{ | ||||
| 			headers: []options.Header{ | ||||
| 				{ | ||||
| 					Name:                 "Claim", | ||||
| 					PreserveRequestValue: true, | ||||
| 					Values: []options.HeaderValue{ | ||||
| 						{ | ||||
| 							ClaimSource: &options.ClaimSource{ | ||||
| 								Claim: "id_token", | ||||
| 							}, | ||||
| 						}, | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			initialHeaders: http.Header{ | ||||
| 				"Claim": []string{"bar", "baz"}, | ||||
| 			}, | ||||
| 			session: &sessionsapi.SessionState{ | ||||
| 				IDToken: "IDToken-1234", | ||||
| 			}, | ||||
| 			expectedHeaders: http.Header{ | ||||
| 				"Claim": []string{"bar", "baz", "IDToken-1234"}, | ||||
| 			}, | ||||
| 			expectedErr: "", | ||||
| 		}), | ||||
| 		Entry("with a claim valued header that's not present (without preservation)", headersTableInput{ | ||||
| 			headers: []options.Header{ | ||||
| 				{ | ||||
| 					Name: "Claim", | ||||
| 					Values: []options.HeaderValue{ | ||||
| 						{ | ||||
| 							ClaimSource: &options.ClaimSource{ | ||||
| 								Claim: "id_token", | ||||
| 							}, | ||||
| 						}, | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			initialHeaders: http.Header{ | ||||
| 				"Claim": []string{"bar", "baz"}, | ||||
| 			}, | ||||
| 			session: nil, | ||||
| 			expectedHeaders: http.Header{ | ||||
| 				"Claim": []string{"bar", "baz"}, | ||||
| 			}, | ||||
| 			expectedErr: "", | ||||
| 		}), | ||||
| 		Entry("with a claim valued header that's not present (with preservation)", headersTableInput{ | ||||
| 			headers: []options.Header{ | ||||
| 				{ | ||||
| 					Name:                 "Claim", | ||||
| 					PreserveRequestValue: true, | ||||
| 					Values: []options.HeaderValue{ | ||||
| 						{ | ||||
| 							ClaimSource: &options.ClaimSource{ | ||||
| 								Claim: "id_token", | ||||
| 							}, | ||||
| 						}, | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			initialHeaders: http.Header{ | ||||
| 				"Claim": []string{"bar", "baz"}, | ||||
| 			}, | ||||
| 			session: nil, | ||||
| 			expectedHeaders: http.Header{ | ||||
| 				"Claim": []string{"bar", "baz"}, | ||||
| 			}, | ||||
| 			expectedErr: "", | ||||
| 		}), | ||||
| 		Entry("with an invalid basicAuthPassword claim valued header", headersTableInput{ | ||||
| 			headers: []options.Header{ | ||||
| 				{ | ||||
| 					Name: "X-Auth-Request-Authorization", | ||||
| 					Values: []options.HeaderValue{ | ||||
| 						{ | ||||
| 							ClaimSource: &options.ClaimSource{ | ||||
| 								Claim: "user", | ||||
| 								BasicAuthPassword: &options.SecretSource{ | ||||
| 									Value:   []byte(base64.StdEncoding.EncodeToString([]byte("basic-password"))), | ||||
| 									FromEnv: "SECRET_ENV", | ||||
| 								}, | ||||
| 							}, | ||||
| 						}, | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			initialHeaders: http.Header{ | ||||
| 				"foo": []string{"bar", "baz"}, | ||||
| 			}, | ||||
| 			session: &sessionsapi.SessionState{ | ||||
| 				User: "user-123", | ||||
| 			}, | ||||
| 			expectedHeaders: nil, | ||||
| 			expectedErr:     "error building response header injector: error building response injector: error building injector for header \"X-Auth-Request-Authorization\": error loading basicAuthPassword: secret source is invalid: exactly one entry required, specify either value, fromEnv or fromFile", | ||||
| 		}), | ||||
| 	) | ||||
| }) | ||||
		Loading…
	
		Reference in New Issue