diff --git a/CHANGELOG.md b/CHANGELOG.md index 1467837f..d0a800f1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ ## Changes since v7.13.0 +- [#3290](https://github.com/oauth2-proxy/oauth2-proxy/pull/3290) fix: WebSocket proxy to respect PassHostHeader setting (@UnsignedLong) - [#3197](https://github.com/oauth2-proxy/oauth2-proxy/pull/3197) fix: NewRemoteKeySet is not using DefaultHTTPClient (@rsrdesarrollo / @tuunit) - [#3292](https://github.com/oauth2-proxy/oauth2-proxy/pull/3292) chore(deps): upgrade gomod and bump to golang v1.25.5 (@tuunit) - [#3304](https://github.com/oauth2-proxy/oauth2-proxy/pull/3304) fix: added conditional so default is not always set and env vars are honored fixes 3303 (@pixeldrew) diff --git a/pkg/upstream/http.go b/pkg/upstream/http.go index e8283144..b3c3ae74 100644 --- a/pkg/upstream/http.go +++ b/pkg/upstream/http.go @@ -55,7 +55,7 @@ func newHTTPUpstreamProxy(upstream options.Upstream, u *url.URL, sigData *option // Set up a WebSocket proxy if required var wsProxy http.Handler if ptr.Deref(upstream.ProxyWebSockets, options.DefaultUpstreamProxyWebSockets) { - wsProxy = newWebSocketReverseProxy(u, upstream.InsecureSkipTLSVerify) + wsProxy = newWebSocketReverseProxy(u, upstream.InsecureSkipTLSVerify, upstream.PassHostHeader) } var auth hmacauth.HmacAuth @@ -201,7 +201,7 @@ func setProxyDirector(proxy *httputil.ReverseProxy) { } // newWebSocketReverseProxy creates a new reverse proxy for proxying websocket connections. -func newWebSocketReverseProxy(u *url.URL, skipTLSVerify *bool) http.Handler { +func newWebSocketReverseProxy(u *url.URL, skipTLSVerify *bool, passHostHeader *bool) http.Handler { wsProxy := httputil.NewSingleHostReverseProxy(u) // Inherit default transport options from Go's stdlib @@ -215,5 +215,10 @@ func newWebSocketReverseProxy(u *url.URL, skipTLSVerify *bool) http.Handler { // Apply the customized transport to our proxy before returning it wsProxy.Transport = transport + // Set upstream host header if PassHostHeader is false (same as regular HTTP proxy) + if !ptr.Deref(passHostHeader, options.DefaultUpstreamPassHostHeader) { + setProxyUpstreamHostHeader(wsProxy, u) + } + return wsProxy } diff --git a/pkg/upstream/http_test.go b/pkg/upstream/http_test.go index 79dc0e4a..a01d5c09 100644 --- a/pkg/upstream/http_test.go +++ b/pkg/upstream/http_test.go @@ -519,10 +519,11 @@ var _ = Describe("HTTP Upstream Suite", func() { Expect(websocket.Message.Send(ws, []byte(message))).To(Succeed()) var response testWebSocketResponse Expect(websocket.JSON.Receive(ws, &response)).To(Succeed()) - Expect(response).To(Equal(testWebSocketResponse{ - Message: message, - Origin: origin, - })) + + // When PassHostHeader=true (default), the Host should be the client's original request host + Expect(response.Message).To(Equal(message)) + Expect(response.Origin).To(Equal(origin)) + Expect(response.Host).To(Equal(proxyURL.Host)) }) It("will proxy HTTP requests", func() { @@ -530,5 +531,45 @@ var _ = Describe("HTTP Upstream Suite", func() { Expect(err).ToNot(HaveOccurred()) Expect(response.StatusCode).To(Equal(200)) }) + + It("will proxy websockets respecting PassHostHeader=false", func() { + // Create a new proxy server with PassHostHeader=false + flush := 1 * time.Second + timeout := options.DefaultUpstreamTimeout + upstream := options.Upstream{ + ID: "websocketProxyNoPassHost", + PassHostHeader: ptr.To(false), + ProxyWebSockets: ptr.To(true), + InsecureSkipTLSVerify: ptr.To(false), + FlushInterval: &flush, + Timeout: &timeout, + } + + u, err := url.Parse(serverAddr) + Expect(err).ToNot(HaveOccurred()) + + handler := newHTTPUpstreamProxy(upstream, u, nil, nil) + noPassHostServer := httptest.NewServer(middleware.NewScope(false, "X-Request-Id")(handler)) + defer noPassHostServer.Close() + + origin := "http://example.localhost" + message := "Hello, world!" + + proxyURL, err := url.Parse(fmt.Sprintf("http://%s", noPassHostServer.Listener.Addr().String())) + Expect(err).ToNot(HaveOccurred()) + + wsAddr := fmt.Sprintf("ws://%s/", proxyURL.Host) + ws, err := websocket.Dial(wsAddr, "", origin) + Expect(err).ToNot(HaveOccurred()) + + Expect(websocket.Message.Send(ws, []byte(message))).To(Succeed()) + var response testWebSocketResponse + Expect(websocket.JSON.Receive(ws, &response)).To(Succeed()) + + // When PassHostHeader=false, the Host should be the upstream server address + Expect(response.Host).To(Equal(u.Host)) + Expect(response.Message).To(Equal(message)) + Expect(response.Origin).To(Equal(origin)) + }) }) }) diff --git a/pkg/upstream/upstream_suite_test.go b/pkg/upstream/upstream_suite_test.go index f34c8eaf..56aa98b1 100644 --- a/pkg/upstream/upstream_suite_test.go +++ b/pkg/upstream/upstream_suite_test.go @@ -96,6 +96,7 @@ type testHTTPRequest struct { type testWebSocketResponse struct { Message string Origin string + Host string } type testHTTPUpstream struct{} @@ -138,6 +139,7 @@ func (t *testHTTPUpstream) websocketHandler() http.Handler { wsResponse := testWebSocketResponse{ Message: string(data), Origin: ws.Request().Header.Get("Origin"), + Host: ws.Request().Host, } err = websocket.JSON.Send(ws, wsResponse) if err != nil {