Generalize and extend default CreateSessionFromToken
This commit is contained in:
		
							parent
							
								
									44fa8316a1
								
							
						
					
					
						commit
						22f60e9b63
					
				
							
								
								
									
										112
									
								
								oauthproxy.go
								
								
								
								
							
							
						
						
									
										112
									
								
								oauthproxy.go
								
								
								
								
							|  | @ -13,7 +13,6 @@ import ( | |||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/coreos/go-oidc" | ||||
| 	"github.com/justinas/alice" | ||||
| 	ipapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/ip" | ||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||
|  | @ -78,36 +77,34 @@ type OAuthProxy struct { | |||
| 	AuthOnlyPath      string | ||||
| 	UserInfoPath      string | ||||
| 
 | ||||
| 	allowedRoutes           []allowedRoute | ||||
| 	redirectURL             *url.URL // the url to receive requests at
 | ||||
| 	whitelistDomains        []string | ||||
| 	provider                providers.Provider | ||||
| 	providerNameOverride    string | ||||
| 	sessionStore            sessionsapi.SessionStore | ||||
| 	ProxyPrefix             string | ||||
| 	SignInMessage           string | ||||
| 	basicAuthValidator      basic.Validator | ||||
| 	displayHtpasswdForm     bool | ||||
| 	serveMux                http.Handler | ||||
| 	SetXAuthRequest         bool | ||||
| 	PassBasicAuth           bool | ||||
| 	SetBasicAuth            bool | ||||
| 	SkipProviderButton      bool | ||||
| 	PassUserHeaders         bool | ||||
| 	BasicAuthPassword       string | ||||
| 	PassAccessToken         bool | ||||
| 	SetAuthorization        bool | ||||
| 	PassAuthorization       bool | ||||
| 	PreferEmailToUser       bool | ||||
| 	skipAuthPreflight       bool | ||||
| 	skipJwtBearerTokens     bool | ||||
| 	mainJwtBearerVerifier   *oidc.IDTokenVerifier | ||||
| 	extraJwtBearerVerifiers []*oidc.IDTokenVerifier | ||||
| 	templates               *template.Template | ||||
| 	realClientIPParser      ipapi.RealClientIPParser | ||||
| 	trustedIPs              *ip.NetSet | ||||
| 	Banner                  string | ||||
| 	Footer                  string | ||||
| 	allowedRoutes        []allowedRoute | ||||
| 	redirectURL          *url.URL // the url to receive requests at
 | ||||
| 	whitelistDomains     []string | ||||
| 	provider             providers.Provider | ||||
| 	providerNameOverride string | ||||
| 	sessionStore         sessionsapi.SessionStore | ||||
| 	ProxyPrefix          string | ||||
| 	SignInMessage        string | ||||
| 	basicAuthValidator   basic.Validator | ||||
| 	displayHtpasswdForm  bool | ||||
| 	serveMux             http.Handler | ||||
| 	SetXAuthRequest      bool | ||||
| 	PassBasicAuth        bool | ||||
| 	SetBasicAuth         bool | ||||
| 	SkipProviderButton   bool | ||||
| 	PassUserHeaders      bool | ||||
| 	BasicAuthPassword    string | ||||
| 	PassAccessToken      bool | ||||
| 	SetAuthorization     bool | ||||
| 	PassAuthorization    bool | ||||
| 	PreferEmailToUser    bool | ||||
| 	skipAuthPreflight    bool | ||||
| 	skipJwtBearerTokens  bool | ||||
| 	templates            *template.Template | ||||
| 	realClientIPParser   ipapi.RealClientIPParser | ||||
| 	trustedIPs           *ip.NetSet | ||||
| 	Banner               string | ||||
| 	Footer               string | ||||
| 
 | ||||
| 	sessionChain alice.Chain | ||||
| 	headersChain alice.Chain | ||||
|  | @ -202,25 +199,23 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | |||
| 		AuthOnlyPath:      fmt.Sprintf("%s/auth", opts.ProxyPrefix), | ||||
| 		UserInfoPath:      fmt.Sprintf("%s/userinfo", opts.ProxyPrefix), | ||||
| 
 | ||||
| 		ProxyPrefix:             opts.ProxyPrefix, | ||||
| 		provider:                opts.GetProvider(), | ||||
| 		providerNameOverride:    opts.ProviderName, | ||||
| 		sessionStore:            sessionStore, | ||||
| 		serveMux:                upstreamProxy, | ||||
| 		redirectURL:             redirectURL, | ||||
| 		allowedRoutes:           allowedRoutes, | ||||
| 		whitelistDomains:        opts.WhitelistDomains, | ||||
| 		skipAuthPreflight:       opts.SkipAuthPreflight, | ||||
| 		skipJwtBearerTokens:     opts.SkipJwtBearerTokens, | ||||
| 		mainJwtBearerVerifier:   opts.GetOIDCVerifier(), | ||||
| 		extraJwtBearerVerifiers: opts.GetJWTBearerVerifiers(), | ||||
| 		realClientIPParser:      opts.GetRealClientIPParser(), | ||||
| 		SkipProviderButton:      opts.SkipProviderButton, | ||||
| 		templates:               templates, | ||||
| 		trustedIPs:              trustedIPs, | ||||
| 		Banner:                  opts.Banner, | ||||
| 		Footer:                  opts.Footer, | ||||
| 		SignInMessage:           buildSignInMessage(opts), | ||||
| 		ProxyPrefix:          opts.ProxyPrefix, | ||||
| 		provider:             opts.GetProvider(), | ||||
| 		providerNameOverride: opts.ProviderName, | ||||
| 		sessionStore:         sessionStore, | ||||
| 		serveMux:             upstreamProxy, | ||||
| 		redirectURL:          redirectURL, | ||||
| 		allowedRoutes:        allowedRoutes, | ||||
| 		whitelistDomains:     opts.WhitelistDomains, | ||||
| 		skipAuthPreflight:    opts.SkipAuthPreflight, | ||||
| 		skipJwtBearerTokens:  opts.SkipJwtBearerTokens, | ||||
| 		realClientIPParser:   opts.GetRealClientIPParser(), | ||||
| 		SkipProviderButton:   opts.SkipProviderButton, | ||||
| 		templates:            templates, | ||||
| 		trustedIPs:           trustedIPs, | ||||
| 		Banner:               opts.Banner, | ||||
| 		Footer:               opts.Footer, | ||||
| 		SignInMessage:        buildSignInMessage(opts), | ||||
| 
 | ||||
| 		basicAuthValidator:  basicAuthValidator, | ||||
| 		displayHtpasswdForm: basicAuthValidator != nil && opts.DisplayHtpasswdForm, | ||||
|  | @ -266,22 +261,13 @@ func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionSt | |||
| 	chain := alice.New() | ||||
| 
 | ||||
| 	if opts.SkipJwtBearerTokens { | ||||
| 		sessionLoaders := []middlewareapi.TokenToSessionLoader{} | ||||
| 		if opts.GetOIDCVerifier() != nil { | ||||
| 			sessionLoaders = append(sessionLoaders, middlewareapi.TokenToSessionLoader{ | ||||
| 				Verifier: func(ctx context.Context, token string) (interface{}, error) { | ||||
| 					return opts.GetOIDCVerifier().Verify(ctx, token) | ||||
| 				}, | ||||
| 				TokenToSession: opts.GetProvider().CreateSessionFromToken, | ||||
| 			}) | ||||
| 		sessionLoaders := []middlewareapi.TokenToSessionFunc{ | ||||
| 			opts.GetProvider().CreateSessionFromToken, | ||||
| 		} | ||||
| 
 | ||||
| 		for _, verifier := range opts.GetJWTBearerVerifiers() { | ||||
| 			sessionLoaders = append(sessionLoaders, middlewareapi.TokenToSessionLoader{ | ||||
| 				Verifier: func(ctx context.Context, token string) (interface{}, error) { | ||||
| 					return verifier.Verify(ctx, token) | ||||
| 				}, | ||||
| 			}) | ||||
| 			sessionLoaders = append(sessionLoaders, | ||||
| 				middlewareapi.CreateTokenToSessionFunc(verifier.Verify)) | ||||
| 		} | ||||
| 
 | ||||
| 		chain = chain.Append(middleware.NewJwtSessionLoader(sessionLoaders)) | ||||
|  |  | |||
|  | @ -2,25 +2,57 @@ package middleware | |||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 
 | ||||
| 	"github.com/coreos/go-oidc" | ||||
| 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||
| ) | ||||
| 
 | ||||
| // TokenToSessionFunc takes a rawIDToken and an idToken and converts it into a
 | ||||
| // SessionState.
 | ||||
| type TokenToSessionFunc func(ctx context.Context, token string, verify VerifyFunc) (*sessionsapi.SessionState, error) | ||||
| // TokenToSessionFunc takes a raw ID Token and converts it into a SessionState.
 | ||||
| type TokenToSessionFunc func(ctx context.Context, token string) (*sessionsapi.SessionState, error) | ||||
| 
 | ||||
| // VerifyFunc takes a raw bearer token and verifies it
 | ||||
| type VerifyFunc func(ctx context.Context, token string) (interface{}, error) | ||||
| // VerifyFunc takes a raw bearer token and verifies it returning the converted
 | ||||
| // oidc.IDToken representation of the token.
 | ||||
| type VerifyFunc func(ctx context.Context, token string) (*oidc.IDToken, error) | ||||
| 
 | ||||
| // TokenToSessionLoader pairs a token verifier with the correct converter function
 | ||||
| // to convert the ID Token to a SessionState.
 | ||||
| type TokenToSessionLoader struct { | ||||
| 	// Verifier is used to verify that the ID Token was signed by the claimed issuer
 | ||||
| 	// and that the token has not been tampered with.
 | ||||
| 	Verifier VerifyFunc | ||||
| // CreateTokenToSessionFunc provides a handler that is a default implementation
 | ||||
| // for converting a JWT into a session.
 | ||||
| func CreateTokenToSessionFunc(verify VerifyFunc) TokenToSessionFunc { | ||||
| 	return func(ctx context.Context, token string) (*sessionsapi.SessionState, error) { | ||||
| 		var claims struct { | ||||
| 			Subject           string `json:"sub"` | ||||
| 			Email             string `json:"email"` | ||||
| 			Verified          *bool  `json:"email_verified"` | ||||
| 			PreferredUsername string `json:"preferred_username"` | ||||
| 		} | ||||
| 
 | ||||
| 	// TokenToSession converts a raw bearer token to a SessionState.
 | ||||
| 	// (Optional) If not set a default basic implementation is used.
 | ||||
| 	TokenToSession TokenToSessionFunc | ||||
| 		idToken, err := verify(ctx, token) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 
 | ||||
| 		if err := idToken.Claims(&claims); err != nil { | ||||
| 			return nil, fmt.Errorf("failed to parse bearer token claims: %v", err) | ||||
| 		} | ||||
| 
 | ||||
| 		if claims.Email == "" { | ||||
| 			claims.Email = claims.Subject | ||||
| 		} | ||||
| 
 | ||||
| 		if claims.Verified != nil && !*claims.Verified { | ||||
| 			return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) | ||||
| 		} | ||||
| 
 | ||||
| 		newSession := &sessionsapi.SessionState{ | ||||
| 			Email:             claims.Email, | ||||
| 			User:              claims.Subject, | ||||
| 			PreferredUsername: claims.PreferredUsername, | ||||
| 			AccessToken:       token, | ||||
| 			IDToken:           token, | ||||
| 			RefreshToken:      "", | ||||
| 			ExpiresOn:         &idToken.Expiry, | ||||
| 		} | ||||
| 
 | ||||
| 		return newSession, nil | ||||
| 	} | ||||
| } | ||||
|  |  | |||
|  | @ -1,12 +1,10 @@ | |||
| package middleware | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"regexp" | ||||
| 
 | ||||
| 	"github.com/coreos/go-oidc" | ||||
| 	"github.com/justinas/alice" | ||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||
| 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||
|  | @ -16,16 +14,7 @@ import ( | |||
| 
 | ||||
| const jwtRegexFormat = `^ey[IJ][a-zA-Z0-9_-]*\.ey[IJ][a-zA-Z0-9_-]*\.[a-zA-Z0-9_-]+$` | ||||
| 
 | ||||
| func NewJwtSessionLoader(sessionLoaders []middlewareapi.TokenToSessionLoader) alice.Constructor { | ||||
| 	for i, loader := range sessionLoaders { | ||||
| 		if loader.TokenToSession == nil { | ||||
| 			sessionLoaders[i] = middlewareapi.TokenToSessionLoader{ | ||||
| 				Verifier:       loader.Verifier, | ||||
| 				TokenToSession: createSessionFromToken, | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| func NewJwtSessionLoader(sessionLoaders []middlewareapi.TokenToSessionFunc) alice.Constructor { | ||||
| 	js := &jwtSessionLoader{ | ||||
| 		jwtRegex:       regexp.MustCompile(jwtRegexFormat), | ||||
| 		sessionLoaders: sessionLoaders, | ||||
|  | @ -37,7 +26,7 @@ func NewJwtSessionLoader(sessionLoaders []middlewareapi.TokenToSessionLoader) al | |||
| // Authorization headers.
 | ||||
| type jwtSessionLoader struct { | ||||
| 	jwtRegex       *regexp.Regexp | ||||
| 	sessionLoaders []middlewareapi.TokenToSessionLoader | ||||
| 	sessionLoaders []middlewareapi.TokenToSessionFunc | ||||
| } | ||||
| 
 | ||||
| // loadSession attempts to load a session from a JWT stored in an Authorization
 | ||||
|  | @ -83,7 +72,7 @@ func (j *jwtSessionLoader) getJwtSession(req *http.Request) (*sessionsapi.Sessio | |||
| 
 | ||||
| 	errs := []error{fmt.Errorf("unable to verify jwt token: %q", req.Header.Get("Authorization"))} | ||||
| 	for _, loader := range j.sessionLoaders { | ||||
| 		session, err := loader.TokenToSession(req.Context(), token, loader.Verifier) | ||||
| 		session, err := loader(req.Context(), token) | ||||
| 		if err == nil { | ||||
| 			return session, nil | ||||
| 		} else { | ||||
|  | @ -135,48 +124,3 @@ func (j *jwtSessionLoader) getBasicToken(token string) (string, error) { | |||
| 
 | ||||
| 	return "", fmt.Errorf("invalid basic auth token found in authorization header") | ||||
| } | ||||
| 
 | ||||
| // createSessionFromToken is a default implementation for converting
 | ||||
| // a JWT into a session state.
 | ||||
| func createSessionFromToken(ctx context.Context, token string, verify middlewareapi.VerifyFunc) (*sessionsapi.SessionState, error) { | ||||
| 	var claims struct { | ||||
| 		Subject           string `json:"sub"` | ||||
| 		Email             string `json:"email"` | ||||
| 		Verified          *bool  `json:"email_verified"` | ||||
| 		PreferredUsername string `json:"preferred_username"` | ||||
| 	} | ||||
| 
 | ||||
| 	verifiedToken, err := verify(ctx, token) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	idToken, ok := verifiedToken.(*oidc.IDToken) | ||||
| 	if !ok { | ||||
| 		return nil, fmt.Errorf("failed to create IDToken from bearer token: %s", token) | ||||
| 	} | ||||
| 
 | ||||
| 	if err := idToken.Claims(&claims); err != nil { | ||||
| 		return nil, fmt.Errorf("failed to parse bearer token claims: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if claims.Email == "" { | ||||
| 		claims.Email = claims.Subject | ||||
| 	} | ||||
| 
 | ||||
| 	if claims.Verified != nil && !*claims.Verified { | ||||
| 		return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) | ||||
| 	} | ||||
| 
 | ||||
| 	newSession := &sessionsapi.SessionState{ | ||||
| 		Email:             claims.Email, | ||||
| 		User:              claims.Subject, | ||||
| 		PreferredUsername: claims.PreferredUsername, | ||||
| 		AccessToken:       token, | ||||
| 		IDToken:           token, | ||||
| 		RefreshToken:      "", | ||||
| 		ExpiresOn:         &idToken.Expiry, | ||||
| 	} | ||||
| 
 | ||||
| 	return newSession, nil | ||||
| } | ||||
|  |  | |||
|  | @ -26,7 +26,7 @@ import ( | |||
| type noOpKeySet struct { | ||||
| } | ||||
| 
 | ||||
| func (noOpKeySet) VerifySignature(ctx context.Context, jwt string) (payload []byte, err error) { | ||||
| func (noOpKeySet) VerifySignature(_ context.Context, jwt string) (payload []byte, err error) { | ||||
| 	splitStrings := strings.Split(jwt, ".") | ||||
| 	payloadString := splitStrings[1] | ||||
| 	return base64.RawURLEncoding.DecodeString(payloadString) | ||||
|  | @ -78,16 +78,14 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` | |||
| 		const nonVerifiedToken = validToken | ||||
| 
 | ||||
| 		BeforeEach(func() { | ||||
| 			verifier = func(ctx context.Context, token string) (interface{}, error) { | ||||
| 				return oidc.NewVerifier( | ||||
| 					"https://issuer.example.com", | ||||
| 					noOpKeySet{}, | ||||
| 					&oidc.Config{ | ||||
| 						ClientID:        "https://test.myapp.com", | ||||
| 						SkipExpiryCheck: true, | ||||
| 					}, | ||||
| 				).Verify(ctx, token) | ||||
| 			} | ||||
| 			verifier = oidc.NewVerifier( | ||||
| 				"https://issuer.example.com", | ||||
| 				noOpKeySet{}, | ||||
| 				&oidc.Config{ | ||||
| 					ClientID:        "https://test.myapp.com", | ||||
| 					SkipExpiryCheck: true, | ||||
| 				}, | ||||
| 			).Verify | ||||
| 		}) | ||||
| 
 | ||||
| 		type jwtSessionLoaderTableInput struct { | ||||
|  | @ -110,10 +108,8 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` | |||
| 
 | ||||
| 				rw := httptest.NewRecorder() | ||||
| 
 | ||||
| 				sessionLoaders := []middlewareapi.TokenToSessionLoader{ | ||||
| 					{ | ||||
| 						Verifier: verifier, | ||||
| 					}, | ||||
| 				sessionLoaders := []middlewareapi.TokenToSessionFunc{ | ||||
| 					middlewareapi.CreateTokenToSessionFunc(verifier), | ||||
| 				} | ||||
| 
 | ||||
| 				// Create the handler with a next handler that will capture the session
 | ||||
|  | @ -175,24 +171,19 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` | |||
| 		const nonVerifiedToken = validToken | ||||
| 
 | ||||
| 		BeforeEach(func() { | ||||
| 			verifier := func(ctx context.Context, token string) (interface{}, error) { | ||||
| 				return oidc.NewVerifier( | ||||
| 					"https://issuer.example.com", | ||||
| 					noOpKeySet{}, | ||||
| 					&oidc.Config{ | ||||
| 						ClientID:        "https://test.myapp.com", | ||||
| 						SkipExpiryCheck: true, | ||||
| 					}, | ||||
| 				).Verify(ctx, token) | ||||
| 			} | ||||
| 			verifier := oidc.NewVerifier( | ||||
| 				"https://issuer.example.com", | ||||
| 				noOpKeySet{}, | ||||
| 				&oidc.Config{ | ||||
| 					ClientID:        "https://test.myapp.com", | ||||
| 					SkipExpiryCheck: true, | ||||
| 				}, | ||||
| 			).Verify | ||||
| 
 | ||||
| 			j = &jwtSessionLoader{ | ||||
| 				jwtRegex: regexp.MustCompile(jwtRegexFormat), | ||||
| 				sessionLoaders: []middlewareapi.TokenToSessionLoader{ | ||||
| 					{ | ||||
| 						Verifier:       verifier, | ||||
| 						TokenToSession: createSessionFromToken, | ||||
| 					}, | ||||
| 				sessionLoaders: []middlewareapi.TokenToSessionFunc{ | ||||
| 					middlewareapi.CreateTokenToSessionFunc(verifier), | ||||
| 				}, | ||||
| 			} | ||||
| 		}) | ||||
|  | @ -402,7 +393,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` | |||
| 		) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("createSessionFromToken", func() { | ||||
| 	Context("CreateTokenToSessionFunc", func() { | ||||
| 		ctx := context.Background() | ||||
| 		expiresFuture := time.Now().Add(time.Duration(5) * time.Minute) | ||||
| 		verified := true | ||||
|  | @ -414,7 +405,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` | |||
| 			jwt.StandardClaims | ||||
| 		} | ||||
| 
 | ||||
| 		type createSessionStateTableInput struct { | ||||
| 		type tokenToSessionTableInput struct { | ||||
| 			idToken         idTokenClaims | ||||
| 			expectedErr     error | ||||
| 			expectedUser    string | ||||
|  | @ -423,8 +414,8 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` | |||
| 		} | ||||
| 
 | ||||
| 		DescribeTable("when creating a session from an IDToken", | ||||
| 			func(in createSessionStateTableInput) { | ||||
| 				verifier := func(ctx context.Context, token string) (interface{}, error) { | ||||
| 			func(in tokenToSessionTableInput) { | ||||
| 				verifier := func(ctx context.Context, token string) (*oidc.IDToken, error) { | ||||
| 					oidcVerifier := oidc.NewVerifier( | ||||
| 						"https://issuer.example.com", | ||||
| 						noOpKeySet{}, | ||||
|  | @ -443,7 +434,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` | |||
| 				rawIDToken, err := jwt.NewWithClaims(jwt.SigningMethodRS256, in.idToken).SignedString(key) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 				session, err := createSessionFromToken(ctx, rawIDToken, verifier) | ||||
| 				session, err := middlewareapi.CreateTokenToSessionFunc(verifier)(ctx, rawIDToken) | ||||
| 				if in.expectedErr != nil { | ||||
| 					Expect(err).To(MatchError(in.expectedErr)) | ||||
| 					Expect(session).To(BeNil()) | ||||
|  | @ -459,7 +450,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` | |||
| 				Expect(session.RefreshToken).To(BeEmpty()) | ||||
| 				Expect(session.PreferredUsername).To(BeEmpty()) | ||||
| 			}, | ||||
| 			Entry("with no email", createSessionStateTableInput{ | ||||
| 			Entry("with no email", tokenToSessionTableInput{ | ||||
| 				idToken: idTokenClaims{ | ||||
| 					StandardClaims: jwt.StandardClaims{ | ||||
| 						Audience:  "asdf1234", | ||||
|  | @ -476,7 +467,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` | |||
| 				expectedEmail:   "123456789", | ||||
| 				expectedExpires: &expiresFuture, | ||||
| 			}), | ||||
| 			Entry("with a verified email", createSessionStateTableInput{ | ||||
| 			Entry("with a verified email", tokenToSessionTableInput{ | ||||
| 				idToken: idTokenClaims{ | ||||
| 					StandardClaims: jwt.StandardClaims{ | ||||
| 						Audience:  "asdf1234", | ||||
|  | @ -495,7 +486,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` | |||
| 				expectedEmail:   "foo@example.com", | ||||
| 				expectedExpires: &expiresFuture, | ||||
| 			}), | ||||
| 			Entry("with a non-verified email", createSessionStateTableInput{ | ||||
| 			Entry("with a non-verified email", tokenToSessionTableInput{ | ||||
| 				idToken: idTokenClaims{ | ||||
| 					StandardClaims: jwt.StandardClaims{ | ||||
| 						Audience:  "asdf1234", | ||||
|  |  | |||
|  | @ -233,6 +233,9 @@ func parseProviderInfo(o *options.Options, msgs []string) []string { | |||
| 	p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs) | ||||
| 	p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs) | ||||
| 
 | ||||
| 	// Make the OIDC Verifier accessible to all providers that can support it
 | ||||
| 	p.Verifier = o.GetOIDCVerifier() | ||||
| 
 | ||||
| 	p.SetAllowedGroups(o.AllowedGroups) | ||||
| 
 | ||||
| 	provider := providers.New(o.ProviderType, p) | ||||
|  | @ -273,18 +276,14 @@ func parseProviderInfo(o *options.Options, msgs []string) []string { | |||
| 		p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail | ||||
| 		p.UserIDClaim = o.UserIDClaim | ||||
| 		p.GroupsClaim = o.OIDCGroupsClaim | ||||
| 		if o.GetOIDCVerifier() == nil { | ||||
| 		if p.Verifier == nil { | ||||
| 			msgs = append(msgs, "oidc provider requires an oidc issuer URL") | ||||
| 		} else { | ||||
| 			p.Verifier = o.GetOIDCVerifier() | ||||
| 		} | ||||
| 	case *providers.GitLabProvider: | ||||
| 		p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail | ||||
| 		p.Groups = o.GitLabGroup | ||||
| 
 | ||||
| 		if o.GetOIDCVerifier() != nil { | ||||
| 			p.Verifier = o.GetOIDCVerifier() | ||||
| 		} else { | ||||
| 		if p.Verifier == nil { | ||||
| 			// Initialize with default verifier for gitlab.com
 | ||||
| 			ctx := context.Background() | ||||
| 
 | ||||
|  |  | |||
|  | @ -11,7 +11,6 @@ import ( | |||
| 	oidc "github.com/coreos/go-oidc" | ||||
| 	"golang.org/x/oauth2" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" | ||||
|  | @ -23,7 +22,6 @@ const emailClaim = "email" | |||
| type OIDCProvider struct { | ||||
| 	*ProviderData | ||||
| 
 | ||||
| 	Verifier             *oidc.IDTokenVerifier | ||||
| 	AllowUnverifiedEmail bool | ||||
| 	UserIDClaim          string | ||||
| 	GroupsClaim          string | ||||
|  | @ -176,17 +174,12 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok | |||
| 	return newSession, nil | ||||
| } | ||||
| 
 | ||||
| func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string, verify middleware.VerifyFunc) (*sessions.SessionState, error) { | ||||
| 	verifiedToken, err := verify(ctx, token) | ||||
| func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) { | ||||
| 	idToken, err := p.Verifier.Verify(ctx, token) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	idToken, ok := verifiedToken.(*oidc.IDToken) | ||||
| 	if !ok { | ||||
| 		return nil, fmt.Errorf("failed to create IDToken from bearer token: %s", token) | ||||
| 	} | ||||
| 
 | ||||
| 	newSession, err := p.createSessionStateInternal(ctx, idToken, nil) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
|  |  | |||
|  | @ -144,16 +144,17 @@ func newOIDCProvider(serverURL *url.URL) *OIDCProvider { | |||
| 			Scheme: serverURL.Scheme, | ||||
| 			Host:   serverURL.Host, | ||||
| 			Path:   "/api"}, | ||||
| 		Scope: "openid profile offline_access"} | ||||
| 
 | ||||
| 	p := &OIDCProvider{ | ||||
| 		ProviderData: providerData, | ||||
| 		Scope: "openid profile offline_access", | ||||
| 		Verifier: oidc.NewVerifier( | ||||
| 			"https://issuer.example.com", | ||||
| 			fakeKeySetStub{}, | ||||
| 			&oidc.Config{ClientID: clientID}, | ||||
| 		), | ||||
| 		UserIDClaim: "email", | ||||
| 	} | ||||
| 
 | ||||
| 	p := &OIDCProvider{ | ||||
| 		ProviderData: providerData, | ||||
| 		UserIDClaim:  "email", | ||||
| 	} | ||||
| 
 | ||||
| 	return p | ||||
|  | @ -347,18 +348,7 @@ func TestCreateSessionStateFromBearerToken(t *testing.T) { | |||
| 			rawIDToken, err := newSignedTestIDToken(tc.IDToken) | ||||
| 			assert.NoError(t, err) | ||||
| 
 | ||||
| 			verifyFunc := func(ctx context.Context, token string) (interface{}, error) { | ||||
| 				keyset := fakeKeySetStub{} | ||||
| 				verifier := oidc.NewVerifier("https://issuer.example.com", keyset, | ||||
| 					&oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true}) | ||||
| 
 | ||||
| 				idToken, err := verifier.Verify(ctx, token) | ||||
| 				assert.NoError(t, err) | ||||
| 
 | ||||
| 				return idToken, nil | ||||
| 			} | ||||
| 
 | ||||
| 			ss, err := provider.CreateSessionFromToken(context.Background(), rawIDToken, verifyFunc) | ||||
| 			ss, err := provider.CreateSessionFromToken(context.Background(), rawIDToken) | ||||
| 			assert.NoError(t, err) | ||||
| 
 | ||||
| 			assert.Equal(t, tc.ExpectedUser, ss.User) | ||||
|  |  | |||
|  | @ -5,6 +5,7 @@ import ( | |||
| 	"io/ioutil" | ||||
| 	"net/url" | ||||
| 
 | ||||
| 	"github.com/coreos/go-oidc" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||
| ) | ||||
| 
 | ||||
|  | @ -26,6 +27,7 @@ type ProviderData struct { | |||
| 	ClientSecretFile string | ||||
| 	Scope            string | ||||
| 	Prompt           string | ||||
| 	Verifier         *oidc.IDTokenVerifier | ||||
| 
 | ||||
| 	// Universal Group authorization data structure
 | ||||
| 	// any provider can set to consume
 | ||||
|  |  | |||
|  | @ -126,6 +126,9 @@ func (p *ProviderData) RefreshSessionIfNeeded(_ context.Context, _ *sessions.Ses | |||
| 
 | ||||
| // CreateSessionStateFromBearerToken should be implemented to allow providers
 | ||||
| // to convert ID tokens into sessions
 | ||||
| func (p *ProviderData) CreateSessionFromToken(_ context.Context, _ string, _ middleware.VerifyFunc) (*sessions.SessionState, error) { | ||||
| func (p *ProviderData) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) { | ||||
| 	if p.Verifier != nil { | ||||
| 		return middleware.CreateTokenToSessionFunc(p.Verifier.Verify)(ctx, token) | ||||
| 	} | ||||
| 	return nil, ErrNotImplemented | ||||
| } | ||||
|  |  | |||
|  | @ -3,7 +3,6 @@ package providers | |||
| import ( | ||||
| 	"context" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||
| ) | ||||
| 
 | ||||
|  | @ -18,7 +17,7 @@ type Provider interface { | |||
| 	ValidateSession(ctx context.Context, s *sessions.SessionState) bool | ||||
| 	GetLoginURL(redirectURI, finalRedirect string) string | ||||
| 	RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) | ||||
| 	CreateSessionFromToken(ctx context.Context, token string, verify middleware.VerifyFunc) (*sessions.SessionState, error) | ||||
| 	CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) | ||||
| } | ||||
| 
 | ||||
| // New provides a new Provider based on the configured provider string
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue