From 1d6721f7ba94b5556cb7e07b93a31489a84bb7ec Mon Sep 17 00:00:00 2001 From: Pascal Date: Fri, 16 Jan 2026 20:30:16 +0100 Subject: [PATCH] fix: WebSocket proxy to respect PassHostHeader setting (#3290) * Fix WebSocket proxy to respect PassHostHeader setting When PassHostHeader is set to false, the regular HTTP proxy correctly sets the Host header to the upstream backend URL. However, the WebSocket proxy was not respecting this setting, causing WebSocket connections to fail when backend services validate the Host header. This commit: - Adds passHostHeader parameter to newWebSocketReverseProxy() - Applies setProxyUpstreamHostHeader() when PassHostHeader=false - Ensures consistent behavior between HTTP and WebSocket proxies Fixes #3288 Signed-off-by: Pascal Schmiel * chore(): add tests, update changelog Signed-off-by: Pascal Schmiel --------- Signed-off-by: Pascal Schmiel --- CHANGELOG.md | 1 + pkg/upstream/http.go | 9 ++++-- pkg/upstream/http_test.go | 49 ++++++++++++++++++++++++++--- pkg/upstream/upstream_suite_test.go | 2 ++ 4 files changed, 55 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1467837f..d0a800f1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ ## Changes since v7.13.0 +- [#3290](https://github.com/oauth2-proxy/oauth2-proxy/pull/3290) fix: WebSocket proxy to respect PassHostHeader setting (@UnsignedLong) - [#3197](https://github.com/oauth2-proxy/oauth2-proxy/pull/3197) fix: NewRemoteKeySet is not using DefaultHTTPClient (@rsrdesarrollo / @tuunit) - [#3292](https://github.com/oauth2-proxy/oauth2-proxy/pull/3292) chore(deps): upgrade gomod and bump to golang v1.25.5 (@tuunit) - [#3304](https://github.com/oauth2-proxy/oauth2-proxy/pull/3304) fix: added conditional so default is not always set and env vars are honored fixes 3303 (@pixeldrew) diff --git a/pkg/upstream/http.go b/pkg/upstream/http.go index e8283144..b3c3ae74 100644 --- a/pkg/upstream/http.go +++ b/pkg/upstream/http.go @@ -55,7 +55,7 @@ func newHTTPUpstreamProxy(upstream options.Upstream, u *url.URL, sigData *option // Set up a WebSocket proxy if required var wsProxy http.Handler if ptr.Deref(upstream.ProxyWebSockets, options.DefaultUpstreamProxyWebSockets) { - wsProxy = newWebSocketReverseProxy(u, upstream.InsecureSkipTLSVerify) + wsProxy = newWebSocketReverseProxy(u, upstream.InsecureSkipTLSVerify, upstream.PassHostHeader) } var auth hmacauth.HmacAuth @@ -201,7 +201,7 @@ func setProxyDirector(proxy *httputil.ReverseProxy) { } // newWebSocketReverseProxy creates a new reverse proxy for proxying websocket connections. -func newWebSocketReverseProxy(u *url.URL, skipTLSVerify *bool) http.Handler { +func newWebSocketReverseProxy(u *url.URL, skipTLSVerify *bool, passHostHeader *bool) http.Handler { wsProxy := httputil.NewSingleHostReverseProxy(u) // Inherit default transport options from Go's stdlib @@ -215,5 +215,10 @@ func newWebSocketReverseProxy(u *url.URL, skipTLSVerify *bool) http.Handler { // Apply the customized transport to our proxy before returning it wsProxy.Transport = transport + // Set upstream host header if PassHostHeader is false (same as regular HTTP proxy) + if !ptr.Deref(passHostHeader, options.DefaultUpstreamPassHostHeader) { + setProxyUpstreamHostHeader(wsProxy, u) + } + return wsProxy } diff --git a/pkg/upstream/http_test.go b/pkg/upstream/http_test.go index 79dc0e4a..a01d5c09 100644 --- a/pkg/upstream/http_test.go +++ b/pkg/upstream/http_test.go @@ -519,10 +519,11 @@ var _ = Describe("HTTP Upstream Suite", func() { Expect(websocket.Message.Send(ws, []byte(message))).To(Succeed()) var response testWebSocketResponse Expect(websocket.JSON.Receive(ws, &response)).To(Succeed()) - Expect(response).To(Equal(testWebSocketResponse{ - Message: message, - Origin: origin, - })) + + // When PassHostHeader=true (default), the Host should be the client's original request host + Expect(response.Message).To(Equal(message)) + Expect(response.Origin).To(Equal(origin)) + Expect(response.Host).To(Equal(proxyURL.Host)) }) It("will proxy HTTP requests", func() { @@ -530,5 +531,45 @@ var _ = Describe("HTTP Upstream Suite", func() { Expect(err).ToNot(HaveOccurred()) Expect(response.StatusCode).To(Equal(200)) }) + + It("will proxy websockets respecting PassHostHeader=false", func() { + // Create a new proxy server with PassHostHeader=false + flush := 1 * time.Second + timeout := options.DefaultUpstreamTimeout + upstream := options.Upstream{ + ID: "websocketProxyNoPassHost", + PassHostHeader: ptr.To(false), + ProxyWebSockets: ptr.To(true), + InsecureSkipTLSVerify: ptr.To(false), + FlushInterval: &flush, + Timeout: &timeout, + } + + u, err := url.Parse(serverAddr) + Expect(err).ToNot(HaveOccurred()) + + handler := newHTTPUpstreamProxy(upstream, u, nil, nil) + noPassHostServer := httptest.NewServer(middleware.NewScope(false, "X-Request-Id")(handler)) + defer noPassHostServer.Close() + + origin := "http://example.localhost" + message := "Hello, world!" + + proxyURL, err := url.Parse(fmt.Sprintf("http://%s", noPassHostServer.Listener.Addr().String())) + Expect(err).ToNot(HaveOccurred()) + + wsAddr := fmt.Sprintf("ws://%s/", proxyURL.Host) + ws, err := websocket.Dial(wsAddr, "", origin) + Expect(err).ToNot(HaveOccurred()) + + Expect(websocket.Message.Send(ws, []byte(message))).To(Succeed()) + var response testWebSocketResponse + Expect(websocket.JSON.Receive(ws, &response)).To(Succeed()) + + // When PassHostHeader=false, the Host should be the upstream server address + Expect(response.Host).To(Equal(u.Host)) + Expect(response.Message).To(Equal(message)) + Expect(response.Origin).To(Equal(origin)) + }) }) }) diff --git a/pkg/upstream/upstream_suite_test.go b/pkg/upstream/upstream_suite_test.go index f34c8eaf..56aa98b1 100644 --- a/pkg/upstream/upstream_suite_test.go +++ b/pkg/upstream/upstream_suite_test.go @@ -96,6 +96,7 @@ type testHTTPRequest struct { type testWebSocketResponse struct { Message string Origin string + Host string } type testHTTPUpstream struct{} @@ -138,6 +139,7 @@ func (t *testHTTPUpstream) websocketHandler() http.Handler { wsResponse := testWebSocketResponse{ Message: string(data), Origin: ws.Request().Header.Get("Origin"), + Host: ws.Request().Host, } err = websocket.JSON.Send(ws, wsResponse) if err != nil {