From 76e566ef63895de9e2a9221b9014484ff77dc2a8 Mon Sep 17 00:00:00 2001 From: CallumWayve Date: Fri, 26 Sep 2025 17:21:03 +0100 Subject: [PATCH] Escape unencoded OAuth state redirects --- oauthproxy.go | 20 ++++++++++++++++++-- oauthproxy_test.go | 37 ++++++++++++++++++++++++++++--------- 2 files changed, 46 insertions(+), 11 deletions(-) diff --git a/oauthproxy.go b/oauthproxy.go index 7526d641..8c484c04 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -1243,7 +1243,12 @@ func checkAllowedEmails(req *http.Request, s *sessionsapi.SessionState) bool { // encodeState builds the OAuth state param out of our nonce and // original application redirect func encodeState(nonce string, redirect string, encode bool) string { - rawString := fmt.Sprintf("%v:%v", nonce, redirect) + redirectPart := redirect + if !encode { + redirectPart = url.QueryEscape(redirectPart) + } + + rawString := fmt.Sprintf("%v:%v", nonce, redirectPart) if encode { return base64.RawURLEncoding.EncodeToString([]byte(rawString)) } @@ -1263,7 +1268,18 @@ func decodeState(state string, encode bool) (string, string, error) { if len(parsedState) != 2 { return "", "", errors.New("invalid length") } - return parsedState[0], parsedState[1], nil + nonce := parsedState[0] + redirect := parsedState[1] + + if !encode { + unescapedRedirect, err := url.QueryUnescape(redirect) + if err != nil { + return "", "", err + } + redirect = unescapedRedirect + } + + return nonce, redirect, nil } // addHeadersForProxying adds the appropriate headers the request / response for proxying diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 488b8cea..da3da0ef 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -3292,26 +3292,45 @@ func TestAuthOnlyAllowedEmailDomains(t *testing.T) { } func TestStateEncodesCorrectly(t *testing.T) { - state := "some_state_to_test" + state := "https://example.com/callback?foo=bar&baz=qux" nonce := "some_nonce_to_test" encodedResult := encodeState(nonce, state, true) - assert.Equal(t, "c29tZV9ub25jZV90b190ZXN0OnNvbWVfc3RhdGVfdG9fdGVzdA", encodedResult) + expectedEncoded := base64.RawURLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", nonce, state))) + assert.Equal(t, expectedEncoded, encodedResult) notEncodedResult := encodeState(nonce, state, false) - assert.Equal(t, "some_nonce_to_test:some_state_to_test", notEncodedResult) + expectedUnencoded := fmt.Sprintf("%s:%s", nonce, url.QueryEscape(state)) + assert.Equal(t, expectedUnencoded, notEncodedResult) + assert.NotContains(t, notEncodedResult, "&") } func TestStateDecodesCorrectly(t *testing.T) { - nonce, redirect, _ := decodeState("c29tZV9ub25jZV90b190ZXN0OnNvbWVfc3RhdGVfdG9fdGVzdA", true) + state := "https://example.com/callback?foo=bar&baz=qux" + nonce := "some_nonce_to_test" - assert.Equal(t, "some_nonce_to_test", nonce) - assert.Equal(t, "some_state_to_test", redirect) + encodedState := base64.RawURLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", nonce, state))) + decodedNonce, decodedRedirect, err := decodeState(encodedState, true) + assert.NoError(t, err) + assert.Equal(t, nonce, decodedNonce) + assert.Equal(t, state, decodedRedirect) - nonce2, redirect2, _ := decodeState("some_nonce_to_test:some_state_to_test", false) + rawState := fmt.Sprintf("%s:%s", nonce, url.QueryEscape(state)) + decodedNonce2, decodedRedirect2, err := decodeState(rawState, false) + assert.NoError(t, err) + assert.Equal(t, nonce, decodedNonce2) + assert.Equal(t, state, decodedRedirect2) +} - assert.Equal(t, "some_nonce_to_test", nonce2) - assert.Equal(t, "some_state_to_test", redirect2) +func TestStateRoundTripWithMultipleQueryParameters(t *testing.T) { + state := "https://example.com/callback?foo=bar&baz=qux&zap=zazzle" + nonce := "another_nonce" + + encoded := encodeState(nonce, state, false) + decodedNonce, decodedRedirect, err := decodeState(encoded, false) + assert.NoError(t, err) + assert.Equal(t, nonce, decodedNonce) + assert.Equal(t, state, decodedRedirect) } func TestAuthOnlyAllowedEmails(t *testing.T) {