From 1b3b00443ae6b4a56f92cede54433e1c558afc2e Mon Sep 17 00:00:00 2001 From: Nick Meves Date: Fri, 23 Oct 2020 19:35:15 -0700 Subject: [PATCH] Streamline ErrMissingCode in provider Redeem methods --- oauthproxy.go | 2 +- oauthproxy_test.go | 5 +++++ providers/azure.go | 14 ++++++-------- providers/google.go | 3 +-- providers/logingov.go | 17 +++++++---------- providers/provider_default.go | 32 ++++++++++++++++---------------- 6 files changed, 36 insertions(+), 37 deletions(-) diff --git a/oauthproxy.go b/oauthproxy.go index 7d69cd5e..98c82238 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -394,7 +394,7 @@ func (p *OAuthProxy) GetRedirectURI(host string) string { func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (*sessionsapi.SessionState, error) { if code == "" { - return nil, errors.New("missing code") + return nil, providers.ErrMissingCode } redirectURI := p.GetRedirectURI(host) s, err := p.provider.Redeem(ctx, redirectURI, code) diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 132c5d0a..a2733f6d 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -1133,6 +1133,11 @@ func TestUserInfoEndpointAccepted(t *testing.T) { Email: "john.doe@example.com", AccessToken: "my_access_token"} err = test.SaveSession(startSession) assert.NoError(t, err) + + test.proxy.ServeHTTP(test.rw, test.req) + assert.Equal(t, http.StatusOK, test.rw.Code) + bodyBytes, _ := ioutil.ReadAll(test.rw.Body) + assert.Equal(t, "{\"email\":\"john.doe@example.com\"}\n", string(bodyBytes)) } func TestUserInfoEndpointUnauthorizedOnNoCookieSetError(t *testing.T) { diff --git a/providers/azure.go b/providers/azure.go index d65b11f4..e72f1068 100644 --- a/providers/azure.go +++ b/providers/azure.go @@ -108,14 +108,13 @@ func overrideTenantURL(current, defaultURL *url.URL, tenant, path string) { } // Redeem exchanges the OAuth2 authentication token for an ID token -func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { +func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) { if code == "" { - err = errors.New("missing code") - return + return nil, ErrMissingCode } clientSecret, err := p.GetClientSecret() if err != nil { - return + return nil, err } params := url.Values{} @@ -149,15 +148,14 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s created := time.Now() expires := time.Unix(jsonResponse.ExpiresOn, 0) - s = &sessions.SessionState{ + + return &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, IDToken: jsonResponse.IDToken, CreatedAt: &created, ExpiresOn: &expires, RefreshToken: jsonResponse.RefreshToken, - } - return - + }, nil } // RefreshSessionIfNeeded checks if the session has expired and uses the diff --git a/providers/google.go b/providers/google.go index 640f40cf..d355efca 100644 --- a/providers/google.go +++ b/providers/google.go @@ -126,8 +126,7 @@ func claimsFromIDToken(idToken string) (*claims, error) { // Redeem exchanges the OAuth2 authentication token for an ID token func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) { if code == "" { - err := errors.New("missing code") - return nil, err + return nil, ErrMissingCode } clientSecret, err := p.GetClientSecret() if err != nil { diff --git a/providers/logingov.go b/providers/logingov.go index ff48ccc5..44d1cb46 100644 --- a/providers/logingov.go +++ b/providers/logingov.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "crypto/rsa" - "errors" "fmt" "math/rand" "net/url" @@ -153,10 +152,9 @@ func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint } // Redeem exchanges the OAuth2 authentication token for an ID token -func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { +func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) { if code == "" { - err = errors.New("missing code") - return + return nil, ErrMissingCode } claims := &jwt.StandardClaims{ @@ -169,7 +167,7 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) token := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), claims) ss, err := token.SignedString(p.JWTKey) if err != nil { - return + return nil, err } params := url.Values{} @@ -199,28 +197,27 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) // check nonce here err = checkNonce(jsonResponse.IDToken, p) if err != nil { - return + return nil, err } // Get the email address var email string email, err = emailFromUserInfo(ctx, jsonResponse.AccessToken, p.ProfileURL.String()) if err != nil { - return + return nil, err } created := time.Now() expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second) // Store the data that we found in the session state - s = &sessions.SessionState{ + return &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, IDToken: jsonResponse.IDToken, CreatedAt: &created, ExpiresOn: &expires, Email: email, - } - return + }, nil } // GetLoginURL overrides GetLoginURL to add login.gov parameters diff --git a/providers/provider_default.go b/providers/provider_default.go index f7d5ac0e..00b70641 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -19,18 +19,21 @@ var ( // implementation method that doesn't have sensible defaults ErrNotImplemented = errors.New("not implemented") + // ErrMissingCode is returned when a Redeem method is called with an empty + // code + ErrMissingCode = errors.New("missing code") + _ Provider = (*ProviderData)(nil) ) // Redeem provides a default implementation of the OAuth2 token redemption process -func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { +func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) { if code == "" { - err = errors.New("missing code") - return + return nil, ErrMissingCode } clientSecret, err := p.GetClientSecret() if err != nil { - return + return nil, err } params := url.Values{} @@ -59,24 +62,21 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s } err = result.UnmarshalInto(&jsonResponse) if err == nil { - s = &sessions.SessionState{ + return &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, - } - return + }, nil } - var v url.Values - v, err = url.ParseQuery(string(result.Body())) + values, err := url.ParseQuery(string(result.Body())) if err != nil { - return + return nil, err } - if a := v.Get("access_token"); a != "" { + if token := values.Get("access_token"); token != "" { created := time.Now() - s = &sessions.SessionState{AccessToken: a, CreatedAt: &created} - } else { - err = fmt.Errorf("no access token found %s", result.Body()) + return &sessions.SessionState{AccessToken: token, CreatedAt: &created}, nil } - return + + return nil, fmt.Errorf("no access token found %s", result.Body()) } // GetLoginURL with typical oauth parameters @@ -100,7 +100,7 @@ func (p *ProviderData) EnrichSessionState(_ context.Context, _ *sessions.Session // Authorize performs global authorization on an authenticated session. // This is not used for fine-grained per route authorization rules. -func (p *ProviderData) Authorize(ctx context.Context, s *sessions.SessionState) (bool, error) { +func (p *ProviderData) Authorize(_ context.Context, s *sessions.SessionState) (bool, error) { if len(p.AllowedGroups) == 0 { return true, nil }