Escape unencoded OAuth state redirects
This commit is contained in:
parent
c0a087d7f2
commit
76e566ef63
|
|
@ -1243,7 +1243,12 @@ func checkAllowedEmails(req *http.Request, s *sessionsapi.SessionState) bool {
|
||||||
// encodeState builds the OAuth state param out of our nonce and
|
// encodeState builds the OAuth state param out of our nonce and
|
||||||
// original application redirect
|
// original application redirect
|
||||||
func encodeState(nonce string, redirect string, encode bool) string {
|
func encodeState(nonce string, redirect string, encode bool) string {
|
||||||
rawString := fmt.Sprintf("%v:%v", nonce, redirect)
|
redirectPart := redirect
|
||||||
|
if !encode {
|
||||||
|
redirectPart = url.QueryEscape(redirectPart)
|
||||||
|
}
|
||||||
|
|
||||||
|
rawString := fmt.Sprintf("%v:%v", nonce, redirectPart)
|
||||||
if encode {
|
if encode {
|
||||||
return base64.RawURLEncoding.EncodeToString([]byte(rawString))
|
return base64.RawURLEncoding.EncodeToString([]byte(rawString))
|
||||||
}
|
}
|
||||||
|
|
@ -1263,7 +1268,18 @@ func decodeState(state string, encode bool) (string, string, error) {
|
||||||
if len(parsedState) != 2 {
|
if len(parsedState) != 2 {
|
||||||
return "", "", errors.New("invalid length")
|
return "", "", errors.New("invalid length")
|
||||||
}
|
}
|
||||||
return parsedState[0], parsedState[1], nil
|
nonce := parsedState[0]
|
||||||
|
redirect := parsedState[1]
|
||||||
|
|
||||||
|
if !encode {
|
||||||
|
unescapedRedirect, err := url.QueryUnescape(redirect)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
redirect = unescapedRedirect
|
||||||
|
}
|
||||||
|
|
||||||
|
return nonce, redirect, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// addHeadersForProxying adds the appropriate headers the request / response for proxying
|
// addHeadersForProxying adds the appropriate headers the request / response for proxying
|
||||||
|
|
|
||||||
|
|
@ -3292,26 +3292,45 @@ func TestAuthOnlyAllowedEmailDomains(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStateEncodesCorrectly(t *testing.T) {
|
func TestStateEncodesCorrectly(t *testing.T) {
|
||||||
state := "some_state_to_test"
|
state := "https://example.com/callback?foo=bar&baz=qux"
|
||||||
nonce := "some_nonce_to_test"
|
nonce := "some_nonce_to_test"
|
||||||
|
|
||||||
encodedResult := encodeState(nonce, state, true)
|
encodedResult := encodeState(nonce, state, true)
|
||||||
assert.Equal(t, "c29tZV9ub25jZV90b190ZXN0OnNvbWVfc3RhdGVfdG9fdGVzdA", encodedResult)
|
expectedEncoded := base64.RawURLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", nonce, state)))
|
||||||
|
assert.Equal(t, expectedEncoded, encodedResult)
|
||||||
|
|
||||||
notEncodedResult := encodeState(nonce, state, false)
|
notEncodedResult := encodeState(nonce, state, false)
|
||||||
assert.Equal(t, "some_nonce_to_test:some_state_to_test", notEncodedResult)
|
expectedUnencoded := fmt.Sprintf("%s:%s", nonce, url.QueryEscape(state))
|
||||||
|
assert.Equal(t, expectedUnencoded, notEncodedResult)
|
||||||
|
assert.NotContains(t, notEncodedResult, "&")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStateDecodesCorrectly(t *testing.T) {
|
func TestStateDecodesCorrectly(t *testing.T) {
|
||||||
nonce, redirect, _ := decodeState("c29tZV9ub25jZV90b190ZXN0OnNvbWVfc3RhdGVfdG9fdGVzdA", true)
|
state := "https://example.com/callback?foo=bar&baz=qux"
|
||||||
|
nonce := "some_nonce_to_test"
|
||||||
|
|
||||||
assert.Equal(t, "some_nonce_to_test", nonce)
|
encodedState := base64.RawURLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", nonce, state)))
|
||||||
assert.Equal(t, "some_state_to_test", redirect)
|
decodedNonce, decodedRedirect, err := decodeState(encodedState, true)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, nonce, decodedNonce)
|
||||||
|
assert.Equal(t, state, decodedRedirect)
|
||||||
|
|
||||||
nonce2, redirect2, _ := decodeState("some_nonce_to_test:some_state_to_test", false)
|
rawState := fmt.Sprintf("%s:%s", nonce, url.QueryEscape(state))
|
||||||
|
decodedNonce2, decodedRedirect2, err := decodeState(rawState, false)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, nonce, decodedNonce2)
|
||||||
|
assert.Equal(t, state, decodedRedirect2)
|
||||||
|
}
|
||||||
|
|
||||||
assert.Equal(t, "some_nonce_to_test", nonce2)
|
func TestStateRoundTripWithMultipleQueryParameters(t *testing.T) {
|
||||||
assert.Equal(t, "some_state_to_test", redirect2)
|
state := "https://example.com/callback?foo=bar&baz=qux&zap=zazzle"
|
||||||
|
nonce := "another_nonce"
|
||||||
|
|
||||||
|
encoded := encodeState(nonce, state, false)
|
||||||
|
decodedNonce, decodedRedirect, err := decodeState(encoded, false)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, nonce, decodedNonce)
|
||||||
|
assert.Equal(t, state, decodedRedirect)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthOnlyAllowedEmails(t *testing.T) {
|
func TestAuthOnlyAllowedEmails(t *testing.T) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue