Merge pull request #82 from 18F/sign-in-redirect
Redirect to / when /oauth2/sign_in accessed
This commit is contained in:
		
						commit
						b0f0409f2b
					
				|  | @ -307,6 +307,11 @@ func (p *OauthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code | ||||||
| 	p.ClearCookie(rw, req) | 	p.ClearCookie(rw, req) | ||||||
| 	rw.WriteHeader(code) | 	rw.WriteHeader(code) | ||||||
| 
 | 
 | ||||||
|  | 	redirect_url := req.URL.RequestURI() | ||||||
|  | 	if redirect_url == signInPath { | ||||||
|  | 		redirect_url = "/" | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	t := struct { | 	t := struct { | ||||||
| 		ProviderName  string | 		ProviderName  string | ||||||
| 		SignInMessage string | 		SignInMessage string | ||||||
|  | @ -317,7 +322,7 @@ func (p *OauthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code | ||||||
| 		ProviderName:  p.provider.Data().ProviderName, | 		ProviderName:  p.provider.Data().ProviderName, | ||||||
| 		SignInMessage: p.SignInMessage, | 		SignInMessage: p.SignInMessage, | ||||||
| 		CustomLogin:   p.displayCustomLoginForm(), | 		CustomLogin:   p.displayCustomLoginForm(), | ||||||
| 		Redirect:      req.URL.RequestURI(), | 		Redirect:      redirect_url, | ||||||
| 		Version:       VERSION, | 		Version:       VERSION, | ||||||
| 	} | 	} | ||||||
| 	p.templates.ExecuteTemplate(rw, "sign_in.html", t) | 	p.templates.ExecuteTemplate(rw, "sign_in.html", t) | ||||||
|  |  | ||||||
|  | @ -9,6 +9,7 @@ import ( | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
| 	"net/url" | 	"net/url" | ||||||
|  | 	"regexp" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
|  | @ -237,3 +238,69 @@ func TestDoNotForwardAccessTokenUpstream(t *testing.T) { | ||||||
| 	assert.Equal(t, 200, code) | 	assert.Equal(t, 200, code) | ||||||
| 	assert.Equal(t, "No access token found.", payload) | 	assert.Equal(t, "No access token found.", payload) | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | type SignInPageTest struct { | ||||||
|  | 	opts           *Options | ||||||
|  | 	proxy          *OauthProxy | ||||||
|  | 	sign_in_regexp *regexp.Regexp | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | const signInRedirectPattern = `<input type="hidden" name="rd" value="(.*)">` | ||||||
|  | 
 | ||||||
|  | func NewSignInPageTest() *SignInPageTest { | ||||||
|  | 	var sip_test SignInPageTest | ||||||
|  | 
 | ||||||
|  | 	sip_test.opts = NewOptions() | ||||||
|  | 	sip_test.opts.Upstreams = append(sip_test.opts.Upstreams, "unused") | ||||||
|  | 	sip_test.opts.CookieSecret = "foobar" | ||||||
|  | 	sip_test.opts.ClientID = "bazquux" | ||||||
|  | 	sip_test.opts.ClientSecret = "xyzzyplugh" | ||||||
|  | 	sip_test.opts.Validate() | ||||||
|  | 
 | ||||||
|  | 	sip_test.proxy = NewOauthProxy(sip_test.opts, func(email string) bool { | ||||||
|  | 		return true | ||||||
|  | 	}) | ||||||
|  | 	sip_test.sign_in_regexp = regexp.MustCompile(signInRedirectPattern) | ||||||
|  | 
 | ||||||
|  | 	return &sip_test | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (sip_test *SignInPageTest) GetEndpoint(endpoint string) (int, string) { | ||||||
|  | 	rw := httptest.NewRecorder() | ||||||
|  | 	req, _ := http.NewRequest("GET", endpoint, strings.NewReader("")) | ||||||
|  | 	sip_test.proxy.ServeHTTP(rw, req) | ||||||
|  | 	return rw.Code, rw.Body.String() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestSignInPageIncludesTargetRedirect(t *testing.T) { | ||||||
|  | 	sip_test := NewSignInPageTest() | ||||||
|  | 	const endpoint = "/some/random/endpoint" | ||||||
|  | 
 | ||||||
|  | 	code, body := sip_test.GetEndpoint(endpoint) | ||||||
|  | 	assert.Equal(t, 403, code) | ||||||
|  | 
 | ||||||
|  | 	match := sip_test.sign_in_regexp.FindStringSubmatch(body) | ||||||
|  | 	if match == nil { | ||||||
|  | 		t.Fatal("Did not find pattern in body: " + | ||||||
|  | 			signInRedirectPattern + "\nBody:\n" + body) | ||||||
|  | 	} | ||||||
|  | 	if match[1] != endpoint { | ||||||
|  | 		t.Fatal(`expected redirect to "` + endpoint + | ||||||
|  | 			`", but was "` + match[1] + `"`) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) { | ||||||
|  | 	sip_test := NewSignInPageTest() | ||||||
|  | 	code, body := sip_test.GetEndpoint("/oauth2/sign_in") | ||||||
|  | 	assert.Equal(t, 200, code) | ||||||
|  | 
 | ||||||
|  | 	match := sip_test.sign_in_regexp.FindStringSubmatch(body) | ||||||
|  | 	if match == nil { | ||||||
|  | 		t.Fatal("Did not find pattern in body: " + | ||||||
|  | 			signInRedirectPattern + "\nBody:\n" + body) | ||||||
|  | 	} | ||||||
|  | 	if match[1] != "/" { | ||||||
|  | 		t.Fatal(`expected redirect to "/", but was "` + match[1] + `"`) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue