Merge pull request #82 from 18F/sign-in-redirect
Redirect to / when /oauth2/sign_in accessed
This commit is contained in:
		
						commit
						b0f0409f2b
					
				|  | @ -307,6 +307,11 @@ func (p *OauthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code | |||
| 	p.ClearCookie(rw, req) | ||||
| 	rw.WriteHeader(code) | ||||
| 
 | ||||
| 	redirect_url := req.URL.RequestURI() | ||||
| 	if redirect_url == signInPath { | ||||
| 		redirect_url = "/" | ||||
| 	} | ||||
| 
 | ||||
| 	t := struct { | ||||
| 		ProviderName  string | ||||
| 		SignInMessage string | ||||
|  | @ -317,7 +322,7 @@ func (p *OauthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code | |||
| 		ProviderName:  p.provider.Data().ProviderName, | ||||
| 		SignInMessage: p.SignInMessage, | ||||
| 		CustomLogin:   p.displayCustomLoginForm(), | ||||
| 		Redirect:      req.URL.RequestURI(), | ||||
| 		Redirect:      redirect_url, | ||||
| 		Version:       VERSION, | ||||
| 	} | ||||
| 	p.templates.ExecuteTemplate(rw, "sign_in.html", t) | ||||
|  |  | |||
|  | @ -9,6 +9,7 @@ import ( | |||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"net/url" | ||||
| 	"regexp" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | @ -237,3 +238,69 @@ func TestDoNotForwardAccessTokenUpstream(t *testing.T) { | |||
| 	assert.Equal(t, 200, code) | ||||
| 	assert.Equal(t, "No access token found.", payload) | ||||
| } | ||||
| 
 | ||||
| type SignInPageTest struct { | ||||
| 	opts           *Options | ||||
| 	proxy          *OauthProxy | ||||
| 	sign_in_regexp *regexp.Regexp | ||||
| } | ||||
| 
 | ||||
| const signInRedirectPattern = `<input type="hidden" name="rd" value="(.*)">` | ||||
| 
 | ||||
| func NewSignInPageTest() *SignInPageTest { | ||||
| 	var sip_test SignInPageTest | ||||
| 
 | ||||
| 	sip_test.opts = NewOptions() | ||||
| 	sip_test.opts.Upstreams = append(sip_test.opts.Upstreams, "unused") | ||||
| 	sip_test.opts.CookieSecret = "foobar" | ||||
| 	sip_test.opts.ClientID = "bazquux" | ||||
| 	sip_test.opts.ClientSecret = "xyzzyplugh" | ||||
| 	sip_test.opts.Validate() | ||||
| 
 | ||||
| 	sip_test.proxy = NewOauthProxy(sip_test.opts, func(email string) bool { | ||||
| 		return true | ||||
| 	}) | ||||
| 	sip_test.sign_in_regexp = regexp.MustCompile(signInRedirectPattern) | ||||
| 
 | ||||
| 	return &sip_test | ||||
| } | ||||
| 
 | ||||
| func (sip_test *SignInPageTest) GetEndpoint(endpoint string) (int, string) { | ||||
| 	rw := httptest.NewRecorder() | ||||
| 	req, _ := http.NewRequest("GET", endpoint, strings.NewReader("")) | ||||
| 	sip_test.proxy.ServeHTTP(rw, req) | ||||
| 	return rw.Code, rw.Body.String() | ||||
| } | ||||
| 
 | ||||
| func TestSignInPageIncludesTargetRedirect(t *testing.T) { | ||||
| 	sip_test := NewSignInPageTest() | ||||
| 	const endpoint = "/some/random/endpoint" | ||||
| 
 | ||||
| 	code, body := sip_test.GetEndpoint(endpoint) | ||||
| 	assert.Equal(t, 403, code) | ||||
| 
 | ||||
| 	match := sip_test.sign_in_regexp.FindStringSubmatch(body) | ||||
| 	if match == nil { | ||||
| 		t.Fatal("Did not find pattern in body: " + | ||||
| 			signInRedirectPattern + "\nBody:\n" + body) | ||||
| 	} | ||||
| 	if match[1] != endpoint { | ||||
| 		t.Fatal(`expected redirect to "` + endpoint + | ||||
| 			`", but was "` + match[1] + `"`) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) { | ||||
| 	sip_test := NewSignInPageTest() | ||||
| 	code, body := sip_test.GetEndpoint("/oauth2/sign_in") | ||||
| 	assert.Equal(t, 200, code) | ||||
| 
 | ||||
| 	match := sip_test.sign_in_regexp.FindStringSubmatch(body) | ||||
| 	if match == nil { | ||||
| 		t.Fatal("Did not find pattern in body: " + | ||||
| 			signInRedirectPattern + "\nBody:\n" + body) | ||||
| 	} | ||||
| 	if match[1] != "/" { | ||||
| 		t.Fatal(`expected redirect to "/", but was "` + match[1] + `"`) | ||||
| 	} | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue