This commit is contained in:
CallumWayve 2025-10-28 12:39:11 +01:00 committed by GitHub
commit 9b9f04beca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 46 additions and 11 deletions

View File

@ -1243,7 +1243,12 @@ func checkAllowedEmails(req *http.Request, s *sessionsapi.SessionState) bool {
// encodeState builds the OAuth state param out of our nonce and
// original application redirect
func encodeState(nonce string, redirect string, encode bool) string {
rawString := fmt.Sprintf("%v:%v", nonce, redirect)
redirectPart := redirect
if !encode {
redirectPart = url.QueryEscape(redirectPart)
}
rawString := fmt.Sprintf("%v:%v", nonce, redirectPart)
if encode {
return base64.RawURLEncoding.EncodeToString([]byte(rawString))
}
@ -1263,7 +1268,18 @@ func decodeState(state string, encode bool) (string, string, error) {
if len(parsedState) != 2 {
return "", "", errors.New("invalid length")
}
return parsedState[0], parsedState[1], nil
nonce := parsedState[0]
redirect := parsedState[1]
if !encode {
unescapedRedirect, err := url.QueryUnescape(redirect)
if err != nil {
return "", "", err
}
redirect = unescapedRedirect
}
return nonce, redirect, nil
}
// addHeadersForProxying adds the appropriate headers the request / response for proxying

View File

@ -3292,26 +3292,45 @@ func TestAuthOnlyAllowedEmailDomains(t *testing.T) {
}
func TestStateEncodesCorrectly(t *testing.T) {
state := "some_state_to_test"
state := "https://example.com/callback?foo=bar&baz=qux"
nonce := "some_nonce_to_test"
encodedResult := encodeState(nonce, state, true)
assert.Equal(t, "c29tZV9ub25jZV90b190ZXN0OnNvbWVfc3RhdGVfdG9fdGVzdA", encodedResult)
expectedEncoded := base64.RawURLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", nonce, state)))
assert.Equal(t, expectedEncoded, encodedResult)
notEncodedResult := encodeState(nonce, state, false)
assert.Equal(t, "some_nonce_to_test:some_state_to_test", notEncodedResult)
expectedUnencoded := fmt.Sprintf("%s:%s", nonce, url.QueryEscape(state))
assert.Equal(t, expectedUnencoded, notEncodedResult)
assert.NotContains(t, notEncodedResult, "&")
}
func TestStateDecodesCorrectly(t *testing.T) {
nonce, redirect, _ := decodeState("c29tZV9ub25jZV90b190ZXN0OnNvbWVfc3RhdGVfdG9fdGVzdA", true)
state := "https://example.com/callback?foo=bar&baz=qux"
nonce := "some_nonce_to_test"
assert.Equal(t, "some_nonce_to_test", nonce)
assert.Equal(t, "some_state_to_test", redirect)
encodedState := base64.RawURLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", nonce, state)))
decodedNonce, decodedRedirect, err := decodeState(encodedState, true)
assert.NoError(t, err)
assert.Equal(t, nonce, decodedNonce)
assert.Equal(t, state, decodedRedirect)
nonce2, redirect2, _ := decodeState("some_nonce_to_test:some_state_to_test", false)
rawState := fmt.Sprintf("%s:%s", nonce, url.QueryEscape(state))
decodedNonce2, decodedRedirect2, err := decodeState(rawState, false)
assert.NoError(t, err)
assert.Equal(t, nonce, decodedNonce2)
assert.Equal(t, state, decodedRedirect2)
}
assert.Equal(t, "some_nonce_to_test", nonce2)
assert.Equal(t, "some_state_to_test", redirect2)
func TestStateRoundTripWithMultipleQueryParameters(t *testing.T) {
state := "https://example.com/callback?foo=bar&baz=qux&zap=zazzle"
nonce := "another_nonce"
encoded := encodeState(nonce, state, false)
decodedNonce, decodedRedirect, err := decodeState(encoded, false)
assert.NoError(t, err)
assert.Equal(t, nonce, decodedNonce)
assert.Equal(t, state, decodedRedirect)
}
func TestAuthOnlyAllowedEmails(t *testing.T) {