oauthproxy: fix #284 -skip-provider-button for /sign_in route
This commit is contained in:
		
							parent
							
								
									3c51c914ac
								
							
						
					
					
						commit
						b640a69d63
					
				|  | @ -482,7 +482,11 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { | |||
| 		p.SaveSession(rw, req, session) | ||||
| 		http.Redirect(rw, req, redirect, 302) | ||||
| 	} else { | ||||
| 		p.SignInPage(rw, req, 200) | ||||
| 		if p.SkipProviderButton { | ||||
| 			p.OAuthStart(rw, req) | ||||
| 		} else { | ||||
| 			p.SignInPage(rw, req, http.StatusOK) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -3,9 +3,6 @@ package main | |||
| import ( | ||||
| 	"crypto" | ||||
| 	"encoding/base64" | ||||
| 	"github.com/18F/hmacauth" | ||||
| 	"github.com/bitly/oauth2_proxy/providers" | ||||
| 	"github.com/bmizerany/assert" | ||||
| 	"io" | ||||
| 	"io/ioutil" | ||||
| 	"log" | ||||
|  | @ -17,6 +14,10 @@ import ( | |||
| 	"strings" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/18F/hmacauth" | ||||
| 	"github.com/bitly/oauth2_proxy/providers" | ||||
| 	"github.com/bmizerany/assert" | ||||
| ) | ||||
| 
 | ||||
| func init() { | ||||
|  | @ -359,26 +360,30 @@ func TestDoNotForwardAccessTokenUpstream(t *testing.T) { | |||
| } | ||||
| 
 | ||||
| type SignInPageTest struct { | ||||
| 	opts           *Options | ||||
| 	proxy          *OAuthProxy | ||||
| 	sign_in_regexp *regexp.Regexp | ||||
| 	opts                    *Options | ||||
| 	proxy                   *OAuthProxy | ||||
| 	sign_in_regexp          *regexp.Regexp | ||||
| 	sign_in_provider_regexp *regexp.Regexp | ||||
| } | ||||
| 
 | ||||
| const signInRedirectPattern = `<input type="hidden" name="rd" value="(.*)">` | ||||
| const signInSkipProvider = `>Found<` | ||||
| 
 | ||||
| func NewSignInPageTest() *SignInPageTest { | ||||
| func NewSignInPageTest(skipProvider bool) *SignInPageTest { | ||||
| 	var sip_test SignInPageTest | ||||
| 
 | ||||
| 	sip_test.opts = NewOptions() | ||||
| 	sip_test.opts.CookieSecret = "foobar" | ||||
| 	sip_test.opts.ClientID = "bazquux" | ||||
| 	sip_test.opts.ClientSecret = "xyzzyplugh" | ||||
| 	sip_test.opts.SkipProviderButton = skipProvider | ||||
| 	sip_test.opts.Validate() | ||||
| 
 | ||||
| 	sip_test.proxy = NewOAuthProxy(sip_test.opts, func(email string) bool { | ||||
| 		return true | ||||
| 	}) | ||||
| 	sip_test.sign_in_regexp = regexp.MustCompile(signInRedirectPattern) | ||||
| 	sip_test.sign_in_provider_regexp = regexp.MustCompile(signInSkipProvider) | ||||
| 
 | ||||
| 	return &sip_test | ||||
| } | ||||
|  | @ -391,7 +396,7 @@ func (sip_test *SignInPageTest) GetEndpoint(endpoint string) (int, string) { | |||
| } | ||||
| 
 | ||||
| func TestSignInPageIncludesTargetRedirect(t *testing.T) { | ||||
| 	sip_test := NewSignInPageTest() | ||||
| 	sip_test := NewSignInPageTest(false) | ||||
| 	const endpoint = "/some/random/endpoint" | ||||
| 
 | ||||
| 	code, body := sip_test.GetEndpoint(endpoint) | ||||
|  | @ -409,7 +414,7 @@ func TestSignInPageIncludesTargetRedirect(t *testing.T) { | |||
| } | ||||
| 
 | ||||
| func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) { | ||||
| 	sip_test := NewSignInPageTest() | ||||
| 	sip_test := NewSignInPageTest(false) | ||||
| 	code, body := sip_test.GetEndpoint("/oauth2/sign_in") | ||||
| 	assert.Equal(t, 200, code) | ||||
| 
 | ||||
|  | @ -423,6 +428,34 @@ func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) { | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSignInPageSkipProvider(t *testing.T) { | ||||
| 	sip_test := NewSignInPageTest(true) | ||||
| 	const endpoint = "/some/random/endpoint" | ||||
| 
 | ||||
| 	code, body := sip_test.GetEndpoint(endpoint) | ||||
| 	assert.Equal(t, 302, code) | ||||
| 
 | ||||
| 	match := sip_test.sign_in_provider_regexp.FindStringSubmatch(body) | ||||
| 	if match == nil { | ||||
| 		t.Fatal("Did not find pattern in body: " + | ||||
| 			signInSkipProvider + "\nBody:\n" + body) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestSignInPageSkipProviderDirect(t *testing.T) { | ||||
| 	sip_test := NewSignInPageTest(true) | ||||
| 	const endpoint = "/sign_in" | ||||
| 
 | ||||
| 	code, body := sip_test.GetEndpoint(endpoint) | ||||
| 	assert.Equal(t, 302, code) | ||||
| 
 | ||||
| 	match := sip_test.sign_in_provider_regexp.FindStringSubmatch(body) | ||||
| 	if match == nil { | ||||
| 		t.Fatal("Did not find pattern in body: " + | ||||
| 			signInSkipProvider + "\nBody:\n" + body) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type ProcessCookieTest struct { | ||||
| 	opts          *Options | ||||
| 	proxy         *OAuthProxy | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue