Escape unencoded OAuth state redirects
This commit is contained in:
		
							parent
							
								
									c0a087d7f2
								
							
						
					
					
						commit
						76e566ef63
					
				|  | @ -1243,7 +1243,12 @@ func checkAllowedEmails(req *http.Request, s *sessionsapi.SessionState) bool { | ||||||
| // encodeState builds the OAuth state param out of our nonce and
 | // encodeState builds the OAuth state param out of our nonce and
 | ||||||
| // original application redirect
 | // original application redirect
 | ||||||
| func encodeState(nonce string, redirect string, encode bool) string { | func encodeState(nonce string, redirect string, encode bool) string { | ||||||
| 	rawString := fmt.Sprintf("%v:%v", nonce, redirect) | 	redirectPart := redirect | ||||||
|  | 	if !encode { | ||||||
|  | 		redirectPart = url.QueryEscape(redirectPart) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	rawString := fmt.Sprintf("%v:%v", nonce, redirectPart) | ||||||
| 	if encode { | 	if encode { | ||||||
| 		return base64.RawURLEncoding.EncodeToString([]byte(rawString)) | 		return base64.RawURLEncoding.EncodeToString([]byte(rawString)) | ||||||
| 	} | 	} | ||||||
|  | @ -1263,7 +1268,18 @@ func decodeState(state string, encode bool) (string, string, error) { | ||||||
| 	if len(parsedState) != 2 { | 	if len(parsedState) != 2 { | ||||||
| 		return "", "", errors.New("invalid length") | 		return "", "", errors.New("invalid length") | ||||||
| 	} | 	} | ||||||
| 	return parsedState[0], parsedState[1], nil | 	nonce := parsedState[0] | ||||||
|  | 	redirect := parsedState[1] | ||||||
|  | 
 | ||||||
|  | 	if !encode { | ||||||
|  | 		unescapedRedirect, err := url.QueryUnescape(redirect) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return "", "", err | ||||||
|  | 		} | ||||||
|  | 		redirect = unescapedRedirect | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nonce, redirect, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // addHeadersForProxying adds the appropriate headers the request / response for proxying
 | // addHeadersForProxying adds the appropriate headers the request / response for proxying
 | ||||||
|  |  | ||||||
|  | @ -3292,26 +3292,45 @@ func TestAuthOnlyAllowedEmailDomains(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestStateEncodesCorrectly(t *testing.T) { | func TestStateEncodesCorrectly(t *testing.T) { | ||||||
| 	state := "some_state_to_test" | 	state := "https://example.com/callback?foo=bar&baz=qux" | ||||||
| 	nonce := "some_nonce_to_test" | 	nonce := "some_nonce_to_test" | ||||||
| 
 | 
 | ||||||
| 	encodedResult := encodeState(nonce, state, true) | 	encodedResult := encodeState(nonce, state, true) | ||||||
| 	assert.Equal(t, "c29tZV9ub25jZV90b190ZXN0OnNvbWVfc3RhdGVfdG9fdGVzdA", encodedResult) | 	expectedEncoded := base64.RawURLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", nonce, state))) | ||||||
|  | 	assert.Equal(t, expectedEncoded, encodedResult) | ||||||
| 
 | 
 | ||||||
| 	notEncodedResult := encodeState(nonce, state, false) | 	notEncodedResult := encodeState(nonce, state, false) | ||||||
| 	assert.Equal(t, "some_nonce_to_test:some_state_to_test", notEncodedResult) | 	expectedUnencoded := fmt.Sprintf("%s:%s", nonce, url.QueryEscape(state)) | ||||||
|  | 	assert.Equal(t, expectedUnencoded, notEncodedResult) | ||||||
|  | 	assert.NotContains(t, notEncodedResult, "&") | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestStateDecodesCorrectly(t *testing.T) { | func TestStateDecodesCorrectly(t *testing.T) { | ||||||
| 	nonce, redirect, _ := decodeState("c29tZV9ub25jZV90b190ZXN0OnNvbWVfc3RhdGVfdG9fdGVzdA", true) | 	state := "https://example.com/callback?foo=bar&baz=qux" | ||||||
|  | 	nonce := "some_nonce_to_test" | ||||||
| 
 | 
 | ||||||
| 	assert.Equal(t, "some_nonce_to_test", nonce) | 	encodedState := base64.RawURLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", nonce, state))) | ||||||
| 	assert.Equal(t, "some_state_to_test", redirect) | 	decodedNonce, decodedRedirect, err := decodeState(encodedState, true) | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 	assert.Equal(t, nonce, decodedNonce) | ||||||
|  | 	assert.Equal(t, state, decodedRedirect) | ||||||
| 
 | 
 | ||||||
| 	nonce2, redirect2, _ := decodeState("some_nonce_to_test:some_state_to_test", false) | 	rawState := fmt.Sprintf("%s:%s", nonce, url.QueryEscape(state)) | ||||||
|  | 	decodedNonce2, decodedRedirect2, err := decodeState(rawState, false) | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 	assert.Equal(t, nonce, decodedNonce2) | ||||||
|  | 	assert.Equal(t, state, decodedRedirect2) | ||||||
|  | } | ||||||
| 
 | 
 | ||||||
| 	assert.Equal(t, "some_nonce_to_test", nonce2) | func TestStateRoundTripWithMultipleQueryParameters(t *testing.T) { | ||||||
| 	assert.Equal(t, "some_state_to_test", redirect2) | 	state := "https://example.com/callback?foo=bar&baz=qux&zap=zazzle" | ||||||
|  | 	nonce := "another_nonce" | ||||||
|  | 
 | ||||||
|  | 	encoded := encodeState(nonce, state, false) | ||||||
|  | 	decodedNonce, decodedRedirect, err := decodeState(encoded, false) | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 	assert.Equal(t, nonce, decodedNonce) | ||||||
|  | 	assert.Equal(t, state, decodedRedirect) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestAuthOnlyAllowedEmails(t *testing.T) { | func TestAuthOnlyAllowedEmails(t *testing.T) { | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue