Merge pull request #796 from grnhse/refactor-redeem-code

Refactor redeemCode and support a Provider-wide EnrichSessionState method
This commit is contained in:
Joel Speed 2020-10-21 11:09:36 +01:00 committed by GitHub
commit 2aa04c9720
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 248 additions and 206 deletions

View File

@ -26,6 +26,7 @@
## Changes since v6.1.1 ## Changes since v6.1.1
- [#825](https://github.com/oauth2-proxy/oauth2-proxy/pull/825) Fix code coverage reporting on GitHub actions(@JoelSpeed) - [#825](https://github.com/oauth2-proxy/oauth2-proxy/pull/825) Fix code coverage reporting on GitHub actions(@JoelSpeed)
- [#796](https://github.com/oauth2-proxy/oauth2-proxy/pull/796) Deprecate GetUserName & GetEmailAdress for EnrichSessionState (@NickMeves)
- [#705](https://github.com/oauth2-proxy/oauth2-proxy/pull/705) Add generic Header injectors for upstream request and response headers (@JoelSpeed) - [#705](https://github.com/oauth2-proxy/oauth2-proxy/pull/705) Add generic Header injectors for upstream request and response headers (@JoelSpeed)
- [#753](https://github.com/oauth2-proxy/oauth2-proxy/pull/753) Pass resource parameter in login url (@codablock) - [#753](https://github.com/oauth2-proxy/oauth2-proxy/pull/753) Pass resource parameter in login url (@codablock)
- [#789](https://github.com/oauth2-proxy/oauth2-proxy/pull/789) Add `--skip-auth-route` configuration option for `METHOD=pathRegex` based allowlists (@NickMeves) - [#789](https://github.com/oauth2-proxy/oauth2-proxy/pull/789) Add `--skip-auth-route` configuration option for `METHOD=pathRegex` based allowlists (@NickMeves)

View File

@ -357,22 +357,19 @@ func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (*sessio
if err != nil { if err != nil {
return nil, err return nil, err
} }
return s, nil
}
func (p *OAuthProxy) enrichSessionState(ctx context.Context, s *sessionsapi.SessionState) error {
var err error
if s.Email == "" { if s.Email == "" {
s.Email, err = p.provider.GetEmailAddress(ctx, s) s.Email, err = p.provider.GetEmailAddress(ctx, s)
if err != nil && err.Error() != "not implemented" { if err != nil && !errors.Is(err, providers.ErrNotImplemented) {
return nil, err return err
} }
} }
if s.User == "" { return p.provider.EnrichSessionState(ctx, s)
s.User, err = p.provider.GetUserName(ctx, s)
if err != nil && err.Error() != "not implemented" {
return nil, err
}
}
return s, nil
} }
// MakeCSRFCookie creates a cookie for CSRF // MakeCSRFCookie creates a cookie for CSRF
@ -829,14 +826,21 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
return return
} }
s := strings.SplitN(req.Form.Get("state"), ":", 2) err = p.enrichSessionState(req.Context(), session)
if len(s) != 2 { if err != nil {
logger.Errorf("Error creating session during OAuth2 callback: %v", err)
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error")
return
}
state := strings.SplitN(req.Form.Get("state"), ":", 2)
if len(state) != 2 {
logger.Error("Error while parsing OAuth2 state: invalid length") logger.Error("Error while parsing OAuth2 state: invalid length")
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Invalid State") p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Invalid State")
return return
} }
nonce := s[0] nonce := state[0]
redirect := s[1] redirect := state[1]
c, err := req.Cookie(p.CSRFCookieName) c, err := req.Cookie(p.CSRFCookieName)
if err != nil { if err != nil {
logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via OAuth2: unable to obtain CSRF cookie") logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via OAuth2: unable to obtain CSRF cookie")

View File

@ -396,14 +396,86 @@ func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider {
} }
} }
func (tp *TestProvider) GetEmailAddress(ctx context.Context, session *sessions.SessionState) (string, error) { func (tp *TestProvider) GetEmailAddress(_ context.Context, _ *sessions.SessionState) (string, error) {
return tp.EmailAddress, nil return tp.EmailAddress, nil
} }
func (tp *TestProvider) ValidateSessionState(ctx context.Context, session *sessions.SessionState) bool { func (tp *TestProvider) ValidateSessionState(_ context.Context, _ *sessions.SessionState) bool {
return tp.ValidToken return tp.ValidToken
} }
func Test_redeemCode(t *testing.T) {
opts := baseTestOptions()
err := validation.Validate(opts)
assert.NoError(t, err)
proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
if err != nil {
t.Fatal(err)
}
_, err = proxy.redeemCode(context.Background(), "www.example.com", "")
assert.Error(t, err)
}
func Test_enrichSession(t *testing.T) {
const (
sessionUser = "Mr Session"
sessionEmail = "session@example.com"
providerEmail = "provider@example.com"
)
testCases := map[string]struct {
session *sessions.SessionState
expectedUser string
expectedEmail string
}{
"Session already has enrichable fields": {
session: &sessions.SessionState{
User: sessionUser,
Email: sessionEmail,
},
expectedUser: sessionUser,
expectedEmail: sessionEmail,
},
"Session is missing Email and GetEmailAddress is implemented": {
session: &sessions.SessionState{
User: sessionUser,
},
expectedUser: sessionUser,
expectedEmail: providerEmail,
},
"Session is missing User and GetUserName is not implemented": {
session: &sessions.SessionState{
Email: sessionEmail,
},
expectedUser: "",
expectedEmail: sessionEmail,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
opts := baseTestOptions()
err := validation.Validate(opts)
assert.NoError(t, err)
// intentionally set after validation.Validate(opts) since it will clobber
// our TestProvider and call `providers.New` defaulting to `providers.GoogleProvider`
opts.SetProvider(NewTestProvider(&url.URL{Host: "www.example.com"}, providerEmail))
proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
if err != nil {
t.Fatal(err)
}
err = proxy.enrichSessionState(context.Background(), tc.session)
assert.NoError(t, err)
assert.Equal(t, tc.expectedUser, tc.session.User)
assert.Equal(t, tc.expectedEmail, tc.session.Email)
})
}
}
func TestBasicAuthPassword(t *testing.T) { func TestBasicAuthPassword(t *testing.T) {
providerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { providerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
logger.Printf("%#v", r) logger.Printf("%#v", r)
@ -1883,7 +1955,7 @@ func TestClearSingleCookie(t *testing.T) {
type NoOpKeySet struct { type NoOpKeySet struct {
} }
func (NoOpKeySet) VerifySignature(ctx context.Context, jwt string) (payload []byte, err error) { func (NoOpKeySet) VerifySignature(_ context.Context, jwt string) (payload []byte, err error) {
splitStrings := strings.Split(jwt, ".") splitStrings := strings.Split(jwt, ".")
payloadString := splitStrings[1] payloadString := splitStrings[1]
return base64.RawURLEncoding.DecodeString(payloadString) return base64.RawURLEncoding.DecodeString(payloadString)

View File

@ -273,7 +273,6 @@ func parseProviderInfo(o *options.Options, msgs []string) []string {
case *providers.GitLabProvider: case *providers.GitLabProvider:
p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail
p.Groups = o.GitLabGroup p.Groups = o.GitLabGroup
p.EmailDomains = o.EmailDomains
if o.GetOIDCVerifier() != nil { if o.GetOIDCVerifier() != nil {
p.Verifier = o.GetOIDCVerifier() p.Verifier = o.GetOIDCVerifier()

View File

@ -102,6 +102,20 @@ func (p *GitHubProvider) SetUsers(users []string) {
p.Users = users p.Users = users
} }
// EnrichSessionState updates the User & Email after the initial Redeem
func (p *GitHubProvider) EnrichSessionState(ctx context.Context, s *sessions.SessionState) error {
err := p.getEmail(ctx, s)
if err != nil {
return err
}
return p.getUser(ctx, s)
}
// ValidateSessionState validates the AccessToken
func (p *GitHubProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool {
return validateToken(ctx, p, s.AccessToken, makeGitHubHeader(s.AccessToken))
}
func (p *GitHubProvider) hasOrg(ctx context.Context, accessToken string) (bool, error) { func (p *GitHubProvider) hasOrg(ctx context.Context, accessToken string) (bool, error) {
// https://developer.github.com/v3/orgs/#list-your-organizations // https://developer.github.com/v3/orgs/#list-your-organizations
@ -364,8 +378,8 @@ func (p *GitHubProvider) isCollaborator(ctx context.Context, username, accessTok
return true, nil return true, nil
} }
// GetEmailAddress returns the Account email address // getEmail updates the SessionState Email
func (p *GitHubProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { func (p *GitHubProvider) getEmail(ctx context.Context, s *sessions.SessionState) error {
var emails []struct { var emails []struct {
Email string `json:"email"` Email string `json:"email"`
@ -379,11 +393,11 @@ func (p *GitHubProvider) GetEmailAddress(ctx context.Context, s *sessions.Sessio
var err error var err error
verifiedUser, err = p.hasUser(ctx, s.AccessToken) verifiedUser, err = p.hasUser(ctx, s.AccessToken)
if err != nil { if err != nil {
return "", err return err
} }
// org and repository options are not configured // org and repository options are not configured
if !verifiedUser && p.Org == "" && p.Repo == "" { if !verifiedUser && p.Org == "" && p.Repo == "" {
return "", errors.New("missing github user") return errors.New("missing github user")
} }
} }
// If a user is verified by username options, skip the following restrictions // If a user is verified by username options, skip the following restrictions
@ -391,16 +405,16 @@ func (p *GitHubProvider) GetEmailAddress(ctx context.Context, s *sessions.Sessio
if p.Org != "" { if p.Org != "" {
if p.Team != "" { if p.Team != "" {
if ok, err := p.hasOrgAndTeam(ctx, s.AccessToken); err != nil || !ok { if ok, err := p.hasOrgAndTeam(ctx, s.AccessToken); err != nil || !ok {
return "", err return err
} }
} else { } else {
if ok, err := p.hasOrg(ctx, s.AccessToken); err != nil || !ok { if ok, err := p.hasOrg(ctx, s.AccessToken); err != nil || !ok {
return "", err return err
} }
} }
} else if p.Repo != "" && p.Token == "" { // If we have a token we'll do the collaborator check in GetUserName } else if p.Repo != "" && p.Token == "" { // If we have a token we'll do the collaborator check in GetUserName
if ok, err := p.hasRepo(ctx, s.AccessToken); err != nil || !ok { if ok, err := p.hasRepo(ctx, s.AccessToken); err != nil || !ok {
return "", err return err
} }
} }
} }
@ -416,24 +430,23 @@ func (p *GitHubProvider) GetEmailAddress(ctx context.Context, s *sessions.Sessio
Do(). Do().
UnmarshalInto(&emails) UnmarshalInto(&emails)
if err != nil { if err != nil {
return "", err return err
} }
returnEmail := ""
for _, email := range emails { for _, email := range emails {
if email.Verified { if email.Verified {
returnEmail = email.Email
if email.Primary { if email.Primary {
return returnEmail, nil s.Email = email.Email
return nil
} }
} }
} }
return returnEmail, nil return nil
} }
// GetUserName returns the Account user name // getUser updates the SessionState User
func (p *GitHubProvider) GetUserName(ctx context.Context, s *sessions.SessionState) (string, error) { func (p *GitHubProvider) getUser(ctx context.Context, s *sessions.SessionState) error {
var user struct { var user struct {
Login string `json:"login"` Login string `json:"login"`
Email string `json:"email"` Email string `json:"email"`
@ -451,22 +464,18 @@ func (p *GitHubProvider) GetUserName(ctx context.Context, s *sessions.SessionSta
Do(). Do().
UnmarshalInto(&user) UnmarshalInto(&user)
if err != nil { if err != nil {
return "", err return err
} }
// Now that we have the username we can check collaborator status // Now that we have the username we can check collaborator status
if !p.isVerifiedUser(user.Login) && p.Org == "" && p.Repo != "" && p.Token != "" { if !p.isVerifiedUser(user.Login) && p.Org == "" && p.Repo != "" && p.Token != "" {
if ok, err := p.isCollaborator(ctx, user.Login, p.Token); err != nil || !ok { if ok, err := p.isCollaborator(ctx, user.Login, p.Token); err != nil || !ok {
return "", err return err
} }
} }
return user.Login, nil s.User = user.Login
} return nil
// ValidateSessionState validates the AccessToken
func (p *GitHubProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool {
return validateToken(ctx, p, s.AccessToken, makeGitHubHeader(s.AccessToken))
} }
// isVerifiedUser // isVerifiedUser

View File

@ -107,7 +107,7 @@ func TestGitHubProviderOverrides(t *testing.T) {
assert.Equal(t, "profile", p.Data().Scope) assert.Equal(t, "profile", p.Data().Scope)
} }
func TestGitHubProviderGetEmailAddress(t *testing.T) { func TestGitHubProvider_getEmail(t *testing.T) {
b := testGitHubBackend(map[string][]string{ b := testGitHubBackend(map[string][]string{
"/user/emails": {`[ {"email": "michael.bland@gsa.gov", "verified": true, "primary": true} ]`}, "/user/emails": {`[ {"email": "michael.bland@gsa.gov", "verified": true, "primary": true} ]`},
}) })
@ -117,14 +117,14 @@ func TestGitHubProviderGetEmailAddress(t *testing.T) {
p := testGitHubProvider(bURL.Host) p := testGitHubProvider(bURL.Host)
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(context.Background(), session) err := p.getEmail(context.Background(), session)
assert.Equal(t, nil, err) assert.NoError(t, err)
assert.Equal(t, "michael.bland@gsa.gov", email) assert.Equal(t, "michael.bland@gsa.gov", session.Email)
} }
func TestGitHubProviderGetEmailAddressNotVerified(t *testing.T) { func TestGitHubProvider_getEmailNotVerified(t *testing.T) {
b := testGitHubBackend(map[string][]string{ b := testGitHubBackend(map[string][]string{
"/user/emails": {`[ {"email": "michael.bland@gsa.gov", "verified": true, "primary": true} ]`}, "/user/emails": {`[ {"email": "michael.bland@gsa.gov", "verified": false, "primary": true} ]`},
}) })
defer b.Close() defer b.Close()
@ -132,12 +132,12 @@ func TestGitHubProviderGetEmailAddressNotVerified(t *testing.T) {
p := testGitHubProvider(bURL.Host) p := testGitHubProvider(bURL.Host)
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(context.Background(), session) err := p.getEmail(context.Background(), session)
assert.Equal(t, nil, err) assert.NoError(t, err)
assert.Empty(t, "", email) assert.Empty(t, session.Email)
} }
func TestGitHubProviderGetEmailAddressWithOrg(t *testing.T) { func TestGitHubProvider_getEmailWithOrg(t *testing.T) {
b := testGitHubBackend(map[string][]string{ b := testGitHubBackend(map[string][]string{
"/user/emails": {`[ {"email": "michael.bland@gsa.gov", "verified": true, "primary": true} ]`}, "/user/emails": {`[ {"email": "michael.bland@gsa.gov", "verified": true, "primary": true} ]`},
"/user/orgs": { "/user/orgs": {
@ -153,12 +153,12 @@ func TestGitHubProviderGetEmailAddressWithOrg(t *testing.T) {
p.Org = "testorg1" p.Org = "testorg1"
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(context.Background(), session) err := p.getEmail(context.Background(), session)
assert.Equal(t, nil, err) assert.NoError(t, err)
assert.Equal(t, "michael.bland@gsa.gov", email) assert.Equal(t, "michael.bland@gsa.gov", session.Email)
} }
func TestGitHubProviderGetEmailAddressWithWriteAccessToPublicRepo(t *testing.T) { func TestGitHubProvider_getEmailWithWriteAccessToPublicRepo(t *testing.T) {
b := testGitHubBackend(map[string][]string{ b := testGitHubBackend(map[string][]string{
"/repo/oauth2-proxy/oauth2-proxy": {`{"permissions": {"pull": true, "push": true}, "private": false}`}, "/repo/oauth2-proxy/oauth2-proxy": {`{"permissions": {"pull": true, "push": true}, "private": false}`},
"/user/emails": {`[ {"email": "michael.bland@gsa.gov", "verified": true, "primary": true} ]`}, "/user/emails": {`[ {"email": "michael.bland@gsa.gov", "verified": true, "primary": true} ]`},
@ -170,12 +170,12 @@ func TestGitHubProviderGetEmailAddressWithWriteAccessToPublicRepo(t *testing.T)
p.SetRepo("oauth2-proxy/oauth2-proxy", "") p.SetRepo("oauth2-proxy/oauth2-proxy", "")
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(context.Background(), session) err := p.getEmail(context.Background(), session)
assert.Equal(t, nil, err) assert.NoError(t, err)
assert.Equal(t, "michael.bland@gsa.gov", email) assert.Equal(t, "michael.bland@gsa.gov", session.Email)
} }
func TestGitHubProviderGetEmailAddressWithReadOnlyAccessToPrivateRepo(t *testing.T) { func TestGitHubProvider_getEmailWithReadOnlyAccessToPrivateRepo(t *testing.T) {
b := testGitHubBackend(map[string][]string{ b := testGitHubBackend(map[string][]string{
"/repo/oauth2-proxy/oauth2-proxy": {`{"permissions": {"pull": true, "push": false}, "private": true}`}, "/repo/oauth2-proxy/oauth2-proxy": {`{"permissions": {"pull": true, "push": false}, "private": true}`},
"/user/emails": {`[ {"email": "michael.bland@gsa.gov", "verified": true, "primary": true} ]`}, "/user/emails": {`[ {"email": "michael.bland@gsa.gov", "verified": true, "primary": true} ]`},
@ -187,12 +187,12 @@ func TestGitHubProviderGetEmailAddressWithReadOnlyAccessToPrivateRepo(t *testing
p.SetRepo("oauth2-proxy/oauth2-proxy", "") p.SetRepo("oauth2-proxy/oauth2-proxy", "")
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(context.Background(), session) err := p.getEmail(context.Background(), session)
assert.Equal(t, nil, err) assert.NoError(t, err)
assert.Equal(t, "michael.bland@gsa.gov", email) assert.Equal(t, "michael.bland@gsa.gov", session.Email)
} }
func TestGitHubProviderGetEmailAddressWithWriteAccessToPrivateRepo(t *testing.T) { func TestGitHubProvider_getEmailWithWriteAccessToPrivateRepo(t *testing.T) {
b := testGitHubBackend(map[string][]string{ b := testGitHubBackend(map[string][]string{
"/repo/oauth2-proxy/oauth2-proxy": {`{"permissions": {"pull": true, "push": true}, "private": true}`}, "/repo/oauth2-proxy/oauth2-proxy": {`{"permissions": {"pull": true, "push": true}, "private": true}`},
"/user/emails": {`[ {"email": "michael.bland@gsa.gov", "verified": true, "primary": true} ]`}, "/user/emails": {`[ {"email": "michael.bland@gsa.gov", "verified": true, "primary": true} ]`},
@ -204,14 +204,14 @@ func TestGitHubProviderGetEmailAddressWithWriteAccessToPrivateRepo(t *testing.T)
p.SetRepo("oauth2-proxy/oauth2-proxy", "") p.SetRepo("oauth2-proxy/oauth2-proxy", "")
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(context.Background(), session) err := p.getEmail(context.Background(), session)
assert.Equal(t, nil, err) assert.NoError(t, err)
assert.Equal(t, "michael.bland@gsa.gov", email) assert.Equal(t, "michael.bland@gsa.gov", session.Email)
} }
func TestGitHubProviderGetEmailAddressWithNoAccessToPrivateRepo(t *testing.T) { func TestGitHubProvider_getEmailWithNoAccessToPrivateRepo(t *testing.T) {
b := testGitHubBackend(map[string][]string{ b := testGitHubBackend(map[string][]string{
"/repo/oauth2-proxy/oauth2-proxy": {}, "/repo/oauth2-proxy/oauth2-proxy": {`{}`},
}) })
defer b.Close() defer b.Close()
@ -220,12 +220,12 @@ func TestGitHubProviderGetEmailAddressWithNoAccessToPrivateRepo(t *testing.T) {
p.SetRepo("oauth2-proxy/oauth2-proxy", "") p.SetRepo("oauth2-proxy/oauth2-proxy", "")
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(context.Background(), session) err := p.getEmail(context.Background(), session)
assert.NotEqual(t, nil, err) assert.NoError(t, err)
assert.Equal(t, "", email) assert.Empty(t, session.Email)
} }
func TestGitHubProviderGetEmailAddressWithToken(t *testing.T) { func TestGitHubProvider_getEmailWithToken(t *testing.T) {
b := testGitHubBackend(map[string][]string{ b := testGitHubBackend(map[string][]string{
"/user/emails": {`[ {"email": "michael.bland@gsa.gov", "verified": true, "primary": true} ]`}, "/user/emails": {`[ {"email": "michael.bland@gsa.gov", "verified": true, "primary": true} ]`},
}) })
@ -236,14 +236,14 @@ func TestGitHubProviderGetEmailAddressWithToken(t *testing.T) {
p.SetRepo("oauth2-proxy/oauth2-proxy", "token") p.SetRepo("oauth2-proxy/oauth2-proxy", "token")
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(context.Background(), session) err := p.getEmail(context.Background(), session)
assert.Equal(t, nil, err) assert.NoError(t, err)
assert.Equal(t, "michael.bland@gsa.gov", email) assert.Equal(t, "michael.bland@gsa.gov", session.Email)
} }
// Note that trying to trigger the "failed building request" case is not // Note that trying to trigger the "failed building request" case is not
// practical, since the only way it can fail is if the URL fails to parse. // practical, since the only way it can fail is if the URL fails to parse.
func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) { func TestGitHubProvider_getEmailFailedRequest(t *testing.T) {
b := testGitHubBackend(map[string][]string{}) b := testGitHubBackend(map[string][]string{})
defer b.Close() defer b.Close()
@ -254,12 +254,12 @@ func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) {
// token. Alternatively, we could allow the parsing of the payload as // token. Alternatively, we could allow the parsing of the payload as
// JSON to fail. // JSON to fail.
session := &sessions.SessionState{AccessToken: "unexpected_access_token"} session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
email, err := p.GetEmailAddress(context.Background(), session) err := p.getEmail(context.Background(), session)
assert.NotEqual(t, nil, err) assert.Error(t, err)
assert.Equal(t, "", email) assert.Empty(t, session.Email)
} }
func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { func TestGitHubProvider_getEmailNotPresentInPayload(t *testing.T) {
b := testGitHubBackend(map[string][]string{ b := testGitHubBackend(map[string][]string{
"/user/emails": {`{"foo": "bar"}`}, "/user/emails": {`{"foo": "bar"}`},
}) })
@ -269,12 +269,12 @@ func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
p := testGitHubProvider(bURL.Host) p := testGitHubProvider(bURL.Host)
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(context.Background(), session) err := p.getEmail(context.Background(), session)
assert.NotEqual(t, nil, err) assert.Error(t, err)
assert.Equal(t, "", email) assert.Empty(t, session.Email)
} }
func TestGitHubProviderGetUserName(t *testing.T) { func TestGitHubProvider_getUser(t *testing.T) {
b := testGitHubBackend(map[string][]string{ b := testGitHubBackend(map[string][]string{
"/user": {`{"email": "michael.bland@gsa.gov", "login": "mbland"}`}, "/user": {`{"email": "michael.bland@gsa.gov", "login": "mbland"}`},
}) })
@ -284,12 +284,12 @@ func TestGitHubProviderGetUserName(t *testing.T) {
p := testGitHubProvider(bURL.Host) p := testGitHubProvider(bURL.Host)
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetUserName(context.Background(), session) err := p.getUser(context.Background(), session)
assert.Equal(t, nil, err) assert.NoError(t, err)
assert.Equal(t, "mbland", email) assert.Equal(t, "mbland", session.User)
} }
func TestGitHubProviderGetUserNameWithRepoAndToken(t *testing.T) { func TestGitHubProvider_getUserWithRepoAndToken(t *testing.T) {
b := testGitHubBackend(map[string][]string{ b := testGitHubBackend(map[string][]string{
"/user": {`{"email": "michael.bland@gsa.gov", "login": "mbland"}`}, "/user": {`{"email": "michael.bland@gsa.gov", "login": "mbland"}`},
"/repos/oauth2-proxy/oauth2-proxy/collaborators/mbland": {""}, "/repos/oauth2-proxy/oauth2-proxy/collaborators/mbland": {""},
@ -301,12 +301,12 @@ func TestGitHubProviderGetUserNameWithRepoAndToken(t *testing.T) {
p.SetRepo("oauth2-proxy/oauth2-proxy", "token") p.SetRepo("oauth2-proxy/oauth2-proxy", "token")
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetUserName(context.Background(), session) err := p.getUser(context.Background(), session)
assert.Equal(t, nil, err) assert.NoError(t, err)
assert.Equal(t, "mbland", email) assert.Equal(t, "mbland", session.User)
} }
func TestGitHubProviderGetUserNameWithRepoAndTokenWithoutPushAccess(t *testing.T) { func TestGitHubProvider_getUserWithRepoAndTokenWithoutPushAccess(t *testing.T) {
b := testGitHubBackend(map[string][]string{}) b := testGitHubBackend(map[string][]string{})
defer b.Close() defer b.Close()
@ -315,12 +315,12 @@ func TestGitHubProviderGetUserNameWithRepoAndTokenWithoutPushAccess(t *testing.T
p.SetRepo("oauth2-proxy/oauth2-proxy", "token") p.SetRepo("oauth2-proxy/oauth2-proxy", "token")
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetUserName(context.Background(), session) err := p.getUser(context.Background(), session)
assert.NotEqual(t, nil, err) assert.Error(t, err)
assert.Equal(t, "", email) assert.Empty(t, session.User)
} }
func TestGitHubProviderGetEmailAddressWithUsername(t *testing.T) { func TestGitHubProvider_getEmailWithUsername(t *testing.T) {
b := testGitHubBackend(map[string][]string{ b := testGitHubBackend(map[string][]string{
"/user": {`{"email": "michael.bland@gsa.gov", "login": "mbland"}`}, "/user": {`{"email": "michael.bland@gsa.gov", "login": "mbland"}`},
"/user/emails": {`[ {"email": "michael.bland@gsa.gov", "verified": true, "primary": true} ]`}, "/user/emails": {`[ {"email": "michael.bland@gsa.gov", "verified": true, "primary": true} ]`},
@ -332,12 +332,12 @@ func TestGitHubProviderGetEmailAddressWithUsername(t *testing.T) {
p.SetUsers([]string{"mbland", "octocat"}) p.SetUsers([]string{"mbland", "octocat"})
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(context.Background(), session) err := p.getEmail(context.Background(), session)
assert.Equal(t, nil, err) assert.NoError(t, err)
assert.Equal(t, "michael.bland@gsa.gov", email) assert.Equal(t, "michael.bland@gsa.gov", session.Email)
} }
func TestGitHubProviderGetEmailAddressWithNotAllowedUsername(t *testing.T) { func TestGitHubProvider_getEmailWithNotAllowedUsername(t *testing.T) {
b := testGitHubBackend(map[string][]string{ b := testGitHubBackend(map[string][]string{
"/user": {`{"email": "michael.bland@gsa.gov", "login": "mbland"}`}, "/user": {`{"email": "michael.bland@gsa.gov", "login": "mbland"}`},
"/user/emails": {`[ {"email": "michael.bland@gsa.gov", "verified": true, "primary": true} ]`}, "/user/emails": {`[ {"email": "michael.bland@gsa.gov", "verified": true, "primary": true} ]`},
@ -349,12 +349,12 @@ func TestGitHubProviderGetEmailAddressWithNotAllowedUsername(t *testing.T) {
p.SetUsers([]string{"octocat"}) p.SetUsers([]string{"octocat"})
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(context.Background(), session) err := p.getEmail(context.Background(), session)
assert.NotEqual(t, nil, err) assert.Error(t, err)
assert.Equal(t, "", email) assert.Empty(t, session.Email)
} }
func TestGitHubProviderGetEmailAddressWithUsernameAndNotBelongToOrg(t *testing.T) { func TestGitHubProvider_getEmailWithUsernameAndNotBelongToOrg(t *testing.T) {
b := testGitHubBackend(map[string][]string{ b := testGitHubBackend(map[string][]string{
"/user": {`{"email": "michael.bland@gsa.gov", "login": "mbland"}`}, "/user": {`{"email": "michael.bland@gsa.gov", "login": "mbland"}`},
"/user/emails": {`[ {"email": "michael.bland@gsa.gov", "verified": true, "primary": true} ]`}, "/user/emails": {`[ {"email": "michael.bland@gsa.gov", "verified": true, "primary": true} ]`},
@ -371,16 +371,16 @@ func TestGitHubProviderGetEmailAddressWithUsernameAndNotBelongToOrg(t *testing.T
p.SetUsers([]string{"mbland"}) p.SetUsers([]string{"mbland"})
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(context.Background(), session) err := p.getEmail(context.Background(), session)
assert.Equal(t, nil, err) assert.NoError(t, err)
assert.Equal(t, "michael.bland@gsa.gov", email) assert.Equal(t, "michael.bland@gsa.gov", session.Email)
} }
func TestGitHubProviderGetEmailAddressWithUsernameAndNoAccessToPrivateRepo(t *testing.T) { func TestGitHubProvider_getEmailWithUsernameAndNoAccessToPrivateRepo(t *testing.T) {
b := testGitHubBackend(map[string][]string{ b := testGitHubBackend(map[string][]string{
"/user": {`{"email": "michael.bland@gsa.gov", "login": "mbland"}`}, "/user": {`{"email": "michael.bland@gsa.gov", "login": "mbland"}`},
"/user/emails": {`[ {"email": "michael.bland@gsa.gov", "verified": true, "primary": true} ]`}, "/user/emails": {`[ {"email": "michael.bland@gsa.gov", "verified": true, "primary": true} ]`},
"/repo/oauth2-proxy/oauth2-proxy": {}, "/repo/oauth2-proxy/oauth2-proxy": {`{}`},
}) })
defer b.Close() defer b.Close()
@ -390,7 +390,7 @@ func TestGitHubProviderGetEmailAddressWithUsernameAndNoAccessToPrivateRepo(t *te
p.SetUsers([]string{"mbland"}) p.SetUsers([]string{"mbland"})
session := CreateAuthorizedSession() session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(context.Background(), session) err := p.getEmail(context.Background(), session)
assert.Equal(t, nil, err) assert.NoError(t, err)
assert.Equal(t, "michael.bland@gsa.gov", email) assert.Equal(t, "michael.bland@gsa.gov", session.Email)
} }

View File

@ -3,7 +3,6 @@ package providers
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"time" "time"
oidc "github.com/coreos/go-oidc" oidc "github.com/coreos/go-oidc"
@ -17,8 +16,6 @@ type GitLabProvider struct {
*ProviderData *ProviderData
Groups []string Groups []string
EmailDomains []string
Verifier *oidc.IDTokenVerifier Verifier *oidc.IDTokenVerifier
AllowUnverifiedEmail bool AllowUnverifiedEmail bool
} }
@ -168,20 +165,6 @@ func (p *GitLabProvider) verifyGroupMembership(userInfo *gitlabUserInfo) error {
return fmt.Errorf("user is not a member of '%s'", p.Groups) return fmt.Errorf("user is not a member of '%s'", p.Groups)
} }
func (p *GitLabProvider) verifyEmailDomain(userInfo *gitlabUserInfo) error {
if len(p.EmailDomains) == 0 || p.EmailDomains[0] == "*" {
return nil
}
for _, domain := range p.EmailDomains {
if strings.HasSuffix(userInfo.Email, domain) {
return nil
}
}
return fmt.Errorf("user email is not one of the valid domains '%v'", p.EmailDomains)
}
func (p *GitLabProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*sessions.SessionState, error) { func (p *GitLabProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*sessions.SessionState, error) {
rawIDToken, ok := token.Extra("id_token").(string) rawIDToken, ok := token.Extra("id_token").(string)
if !ok { if !ok {
@ -211,39 +194,27 @@ func (p *GitLabProvider) ValidateSessionState(ctx context.Context, s *sessions.S
} }
// GetEmailAddress returns the Account email address // GetEmailAddress returns the Account email address
func (p *GitLabProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { func (p *GitLabProvider) EnrichSessionState(ctx context.Context, s *sessions.SessionState) error {
// Retrieve user info // Retrieve user info
userInfo, err := p.getUserInfo(ctx, s) userInfo, err := p.getUserInfo(ctx, s)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to retrieve user info: %v", err) return fmt.Errorf("failed to retrieve user info: %v", err)
} }
// Check if email is verified // Check if email is verified
if !p.AllowUnverifiedEmail && !userInfo.EmailVerified { if !p.AllowUnverifiedEmail && !userInfo.EmailVerified {
return "", fmt.Errorf("user email is not verified") return fmt.Errorf("user email is not verified")
}
// Check if email has valid domain
err = p.verifyEmailDomain(userInfo)
if err != nil {
return "", fmt.Errorf("email domain check failed: %v", err)
} }
// Check group membership // Check group membership
// TODO (@NickMeves) - Refactor to Authorize
err = p.verifyGroupMembership(userInfo) err = p.verifyGroupMembership(userInfo)
if err != nil { if err != nil {
return "", fmt.Errorf("group membership check failed: %v", err) return fmt.Errorf("group membership check failed: %v", err)
} }
return userInfo.Email, nil s.User = userInfo.Username
} s.Email = userInfo.Email
// GetUserName returns the Account user name return nil
func (p *GitLabProvider) GetUserName(ctx context.Context, s *sessions.SessionState) (string, error) {
userInfo, err := p.getUserInfo(ctx, s)
if err != nil {
return "", fmt.Errorf("failed to retrieve user info: %v", err)
}
return userInfo.Username, nil
} }

View File

@ -64,8 +64,8 @@ func TestGitLabProviderBadToken(t *testing.T) {
p := testGitLabProvider(bURL.Host) p := testGitLabProvider(bURL.Host)
session := &sessions.SessionState{AccessToken: "unexpected_gitlab_access_token"} session := &sessions.SessionState{AccessToken: "unexpected_gitlab_access_token"}
_, err := p.GetEmailAddress(context.Background(), session) err := p.EnrichSessionState(context.Background(), session)
assert.NotEqual(t, nil, err) assert.Error(t, err)
} }
func TestGitLabProviderUnverifiedEmailDenied(t *testing.T) { func TestGitLabProviderUnverifiedEmailDenied(t *testing.T) {
@ -76,8 +76,8 @@ func TestGitLabProviderUnverifiedEmailDenied(t *testing.T) {
p := testGitLabProvider(bURL.Host) p := testGitLabProvider(bURL.Host)
session := &sessions.SessionState{AccessToken: "gitlab_access_token"} session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
_, err := p.GetEmailAddress(context.Background(), session) err := p.EnrichSessionState(context.Background(), session)
assert.NotEqual(t, nil, err) assert.Error(t, err)
} }
func TestGitLabProviderUnverifiedEmailAllowed(t *testing.T) { func TestGitLabProviderUnverifiedEmailAllowed(t *testing.T) {
@ -89,9 +89,9 @@ func TestGitLabProviderUnverifiedEmailAllowed(t *testing.T) {
p.AllowUnverifiedEmail = true p.AllowUnverifiedEmail = true
session := &sessions.SessionState{AccessToken: "gitlab_access_token"} session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
email, err := p.GetEmailAddress(context.Background(), session) err := p.EnrichSessionState(context.Background(), session)
assert.Equal(t, nil, err) assert.NoError(t, err)
assert.Equal(t, "foo@bar.com", email) assert.Equal(t, "foo@bar.com", session.Email)
} }
func TestGitLabProviderUsername(t *testing.T) { func TestGitLabProviderUsername(t *testing.T) {
@ -103,9 +103,9 @@ func TestGitLabProviderUsername(t *testing.T) {
p.AllowUnverifiedEmail = true p.AllowUnverifiedEmail = true
session := &sessions.SessionState{AccessToken: "gitlab_access_token"} session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
username, err := p.GetUserName(context.Background(), session) err := p.EnrichSessionState(context.Background(), session)
assert.Equal(t, nil, err) assert.NoError(t, err)
assert.Equal(t, "FooBar", username) assert.Equal(t, "FooBar", session.User)
} }
func TestGitLabProviderGroupMembershipValid(t *testing.T) { func TestGitLabProviderGroupMembershipValid(t *testing.T) {
@ -118,9 +118,9 @@ func TestGitLabProviderGroupMembershipValid(t *testing.T) {
p.Groups = []string{"foo"} p.Groups = []string{"foo"}
session := &sessions.SessionState{AccessToken: "gitlab_access_token"} session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
email, err := p.GetEmailAddress(context.Background(), session) err := p.EnrichSessionState(context.Background(), session)
assert.Equal(t, nil, err) assert.NoError(t, err)
assert.Equal(t, "foo@bar.com", email) assert.Equal(t, "FooBar", session.User)
} }
func TestGitLabProviderGroupMembershipMissing(t *testing.T) { func TestGitLabProviderGroupMembershipMissing(t *testing.T) {
@ -133,35 +133,6 @@ func TestGitLabProviderGroupMembershipMissing(t *testing.T) {
p.Groups = []string{"baz"} p.Groups = []string{"baz"}
session := &sessions.SessionState{AccessToken: "gitlab_access_token"} session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
_, err := p.GetEmailAddress(context.Background(), session) err := p.EnrichSessionState(context.Background(), session)
assert.NotEqual(t, nil, err) assert.Error(t, err)
}
func TestGitLabProviderEmailDomainValid(t *testing.T) {
b := testGitLabBackend()
defer b.Close()
bURL, _ := url.Parse(b.URL)
p := testGitLabProvider(bURL.Host)
p.AllowUnverifiedEmail = true
p.EmailDomains = []string{"bar.com"}
session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
email, err := p.GetEmailAddress(context.Background(), session)
assert.Equal(t, nil, err)
assert.Equal(t, "foo@bar.com", email)
}
func TestGitLabProviderEmailDomainInvalid(t *testing.T) {
b := testGitLabBackend()
defer b.Close()
bURL, _ := url.Parse(b.URL)
p := testGitLabProvider(bURL.Host)
p.AllowUnverifiedEmail = true
p.EmailDomains = []string{"baz.com"}
session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
_, err := p.GetEmailAddress(context.Background(), session)
assert.NotEqual(t, nil, err)
} }

View File

@ -14,7 +14,13 @@ import (
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
) )
var _ Provider = (*ProviderData)(nil) var (
// ErrNotImplemented is returned when a provider did not override a default
// implementation method that doesn't have sensible defaults
ErrNotImplemented = errors.New("not implemented")
_ Provider = (*ProviderData)(nil)
)
// Redeem provides a default implementation of the OAuth2 token redemption process // Redeem provides a default implementation of the OAuth2 token redemption process
func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) { func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
@ -81,21 +87,23 @@ func (p *ProviderData) GetLoginURL(redirectURI, state string) string {
} }
// GetEmailAddress returns the Account email address // GetEmailAddress returns the Account email address
func (p *ProviderData) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { // DEPRECATED: Migrate to EnrichSessionState
return "", errors.New("not implemented") func (p *ProviderData) GetEmailAddress(_ context.Context, _ *sessions.SessionState) (string, error) {
} return "", ErrNotImplemented
// GetUserName returns the Account username
func (p *ProviderData) GetUserName(ctx context.Context, s *sessions.SessionState) (string, error) {
return "", errors.New("not implemented")
} }
// ValidateGroup validates that the provided email exists in the configured provider // ValidateGroup validates that the provided email exists in the configured provider
// email group(s). // email group(s).
func (p *ProviderData) ValidateGroup(email string) bool { func (p *ProviderData) ValidateGroup(_ string) bool {
return true return true
} }
// EnrichSessionState is called after Redeem to allow providers to enrich session fields
// such as User, Email, Groups with provider specific API calls.
func (p *ProviderData) EnrichSessionState(_ context.Context, _ *sessions.SessionState) error {
return nil
}
// ValidateSessionState validates the AccessToken // ValidateSessionState validates the AccessToken
func (p *ProviderData) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool { func (p *ProviderData) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool {
return validateToken(ctx, p, s.AccessToken, nil) return validateToken(ctx, p, s.AccessToken, nil)
@ -103,12 +111,12 @@ func (p *ProviderData) ValidateSessionState(ctx context.Context, s *sessions.Ses
// RefreshSessionIfNeeded should refresh the user's session if required and // RefreshSessionIfNeeded should refresh the user's session if required and
// do nothing if a refresh is not required // do nothing if a refresh is not required
func (p *ProviderData) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { func (p *ProviderData) RefreshSessionIfNeeded(_ context.Context, _ *sessions.SessionState) (bool, error) {
return false, nil return false, nil
} }
// CreateSessionStateFromBearerToken should be implemented to allow providers // CreateSessionStateFromBearerToken should be implemented to allow providers
// to convert ID tokens into sessions // to convert ID tokens into sessions
func (p *ProviderData) CreateSessionStateFromBearerToken(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessions.SessionState, error) { func (p *ProviderData) CreateSessionStateFromBearerToken(_ context.Context, _ string, _ *oidc.IDToken) (*sessions.SessionState, error) {
return nil, errors.New("not implemented") return nil, ErrNotImplemented
} }

View File

@ -47,3 +47,9 @@ func TestAcrValuesConfigured(t *testing.T) {
result := p.GetLoginURL("https://my.test.app/oauth", "") result := p.GetLoginURL("https://my.test.app/oauth", "")
assert.Contains(t, result, "acr_values=testValue") assert.Contains(t, result, "acr_values=testValue")
} }
func TestEnrichSessionState(t *testing.T) {
p := &ProviderData{}
s := &sessions.SessionState{}
assert.NoError(t, p.EnrichSessionState(context.Background(), s))
}

View File

@ -10,10 +10,11 @@ import (
// Provider represents an upstream identity provider implementation // Provider represents an upstream identity provider implementation
type Provider interface { type Provider interface {
Data() *ProviderData Data() *ProviderData
// DEPRECATED: Migrate to EnrichSessionState
GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error)
GetUserName(ctx context.Context, s *sessions.SessionState) (string, error)
Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error) Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error)
ValidateGroup(string) bool ValidateGroup(string) bool
EnrichSessionState(ctx context.Context, s *sessions.SessionState) error
ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool
GetLoginURL(redirectURI, finalRedirect string) string GetLoginURL(redirectURI, finalRedirect string) string
RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error)