Returns HTTP unauthorized for ajax requests instead of redirecting to the sing-in page
This commit is contained in:
		
							parent
							
								
									01c5f5ae3b
								
							
						
					
					
						commit
						c12db0ebf7
					
				|  | @ -754,6 +754,8 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { | ||||||
| 		} else { | 		} else { | ||||||
| 			p.SignInPage(rw, req, http.StatusForbidden) | 			p.SignInPage(rw, req, http.StatusForbidden) | ||||||
| 		} | 		} | ||||||
|  | 	} else if status == http.StatusUnauthorized { | ||||||
|  | 		p.ErrorJSON(rw, status) | ||||||
| 	} else { | 	} else { | ||||||
| 		p.serveMux.ServeHTTP(rw, req) | 		p.serveMux.ServeHTTP(rw, req) | ||||||
| 	} | 	} | ||||||
|  | @ -826,6 +828,11 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if session == nil { | 	if session == nil { | ||||||
|  | 		// Check if is an ajax request and return unauthorized to avoid a redirect
 | ||||||
|  | 		// to the login page
 | ||||||
|  | 		if p.isAjax(req) { | ||||||
|  | 			return http.StatusUnauthorized | ||||||
|  | 		} | ||||||
| 		return http.StatusForbidden | 		return http.StatusForbidden | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -894,3 +901,24 @@ func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, | ||||||
| 	} | 	} | ||||||
| 	return nil, fmt.Errorf("%s not in HtpasswdFile", pair[0]) | 	return nil, fmt.Errorf("%s not in HtpasswdFile", pair[0]) | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // isAjax checks if a request is an ajax request
 | ||||||
|  | func (p *OAuthProxy) isAjax(req *http.Request) bool { | ||||||
|  | 	acceptValues, ok := req.Header["accept"] | ||||||
|  | 	if !ok { | ||||||
|  | 		acceptValues = req.Header["Accept"] | ||||||
|  | 	} | ||||||
|  | 	const ajaxReq = "application/json" | ||||||
|  | 	for _, v := range acceptValues { | ||||||
|  | 		if v == ajaxReq { | ||||||
|  | 			return true | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ErrorJSON returns the error code witht an application/json mime type
 | ||||||
|  | func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) { | ||||||
|  | 	rw.Header().Set("Content-Type", "application/json") | ||||||
|  | 	rw.WriteHeader(code) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -871,3 +871,67 @@ func TestGetRedirect(t *testing.T) { | ||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | type ajaxRequestTest struct { | ||||||
|  | 	opts  *Options | ||||||
|  | 	proxy *OAuthProxy | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func newAjaxRequestTest() *ajaxRequestTest { | ||||||
|  | 	test := &ajaxRequestTest{} | ||||||
|  | 	test.opts = NewOptions() | ||||||
|  | 	test.opts.CookieSecret = "foobar" | ||||||
|  | 	test.opts.ClientID = "bazquux" | ||||||
|  | 	test.opts.ClientSecret = "xyzzyplugh" | ||||||
|  | 	test.opts.Validate() | ||||||
|  | 	test.proxy = NewOAuthProxy(test.opts, func(email string) bool { | ||||||
|  | 		return true | ||||||
|  | 	}) | ||||||
|  | 	return test | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (test *ajaxRequestTest) getEndpoint(endpoint string, header http.Header) (int, http.Header, error) { | ||||||
|  | 	rw := httptest.NewRecorder() | ||||||
|  | 	req, err := http.NewRequest(http.MethodGet, endpoint, strings.NewReader("")) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return 0, nil, err | ||||||
|  | 	} | ||||||
|  | 	req.Header = header | ||||||
|  | 	test.proxy.ServeHTTP(rw, req) | ||||||
|  | 	return rw.Code, rw.Header(), nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func testAjaxUnauthorizedRequest(t *testing.T, header http.Header) { | ||||||
|  | 	test := newAjaxRequestTest() | ||||||
|  | 	const endpoint = "/test" | ||||||
|  | 
 | ||||||
|  | 	code, rh, err := test.getEndpoint(endpoint, header) | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 	assert.Equal(t, http.StatusUnauthorized, code) | ||||||
|  | 	mime := rh.Get("Content-Type") | ||||||
|  | 	assert.Equal(t, "application/json", mime) | ||||||
|  | } | ||||||
|  | func TestAjaxUnauthorizedRequest1(t *testing.T) { | ||||||
|  | 	header := make(http.Header) | ||||||
|  | 	header.Add("accept", "application/json") | ||||||
|  | 
 | ||||||
|  | 	testAjaxUnauthorizedRequest(t, header) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestAjaxUnauthorizedRequest2(t *testing.T) { | ||||||
|  | 	header := make(http.Header) | ||||||
|  | 	header.Add("Accept", "application/json") | ||||||
|  | 
 | ||||||
|  | 	testAjaxUnauthorizedRequest(t, header) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestAjaxForbiddendRequest(t *testing.T) { | ||||||
|  | 	test := newAjaxRequestTest() | ||||||
|  | 	const endpoint = "/test" | ||||||
|  | 	header := make(http.Header) | ||||||
|  | 	code, rh, err := test.getEndpoint(endpoint, header) | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 	assert.Equal(t, http.StatusForbidden, code) | ||||||
|  | 	mime := rh.Get("Content-Type") | ||||||
|  | 	assert.NotEqual(t, "application/json", mime) | ||||||
|  | } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue