Generalize and extend default CreateSessionFromToken
This commit is contained in:
		
							parent
							
								
									44fa8316a1
								
							
						
					
					
						commit
						22f60e9b63
					
				
							
								
								
									
										112
									
								
								oauthproxy.go
								
								
								
								
							
							
						
						
									
										112
									
								
								oauthproxy.go
								
								
								
								
							|  | @ -13,7 +13,6 @@ import ( | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/coreos/go-oidc" |  | ||||||
| 	"github.com/justinas/alice" | 	"github.com/justinas/alice" | ||||||
| 	ipapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/ip" | 	ipapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/ip" | ||||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||||
|  | @ -78,36 +77,34 @@ type OAuthProxy struct { | ||||||
| 	AuthOnlyPath      string | 	AuthOnlyPath      string | ||||||
| 	UserInfoPath      string | 	UserInfoPath      string | ||||||
| 
 | 
 | ||||||
| 	allowedRoutes           []allowedRoute | 	allowedRoutes        []allowedRoute | ||||||
| 	redirectURL             *url.URL // the url to receive requests at
 | 	redirectURL          *url.URL // the url to receive requests at
 | ||||||
| 	whitelistDomains        []string | 	whitelistDomains     []string | ||||||
| 	provider                providers.Provider | 	provider             providers.Provider | ||||||
| 	providerNameOverride    string | 	providerNameOverride string | ||||||
| 	sessionStore            sessionsapi.SessionStore | 	sessionStore         sessionsapi.SessionStore | ||||||
| 	ProxyPrefix             string | 	ProxyPrefix          string | ||||||
| 	SignInMessage           string | 	SignInMessage        string | ||||||
| 	basicAuthValidator      basic.Validator | 	basicAuthValidator   basic.Validator | ||||||
| 	displayHtpasswdForm     bool | 	displayHtpasswdForm  bool | ||||||
| 	serveMux                http.Handler | 	serveMux             http.Handler | ||||||
| 	SetXAuthRequest         bool | 	SetXAuthRequest      bool | ||||||
| 	PassBasicAuth           bool | 	PassBasicAuth        bool | ||||||
| 	SetBasicAuth            bool | 	SetBasicAuth         bool | ||||||
| 	SkipProviderButton      bool | 	SkipProviderButton   bool | ||||||
| 	PassUserHeaders         bool | 	PassUserHeaders      bool | ||||||
| 	BasicAuthPassword       string | 	BasicAuthPassword    string | ||||||
| 	PassAccessToken         bool | 	PassAccessToken      bool | ||||||
| 	SetAuthorization        bool | 	SetAuthorization     bool | ||||||
| 	PassAuthorization       bool | 	PassAuthorization    bool | ||||||
| 	PreferEmailToUser       bool | 	PreferEmailToUser    bool | ||||||
| 	skipAuthPreflight       bool | 	skipAuthPreflight    bool | ||||||
| 	skipJwtBearerTokens     bool | 	skipJwtBearerTokens  bool | ||||||
| 	mainJwtBearerVerifier   *oidc.IDTokenVerifier | 	templates            *template.Template | ||||||
| 	extraJwtBearerVerifiers []*oidc.IDTokenVerifier | 	realClientIPParser   ipapi.RealClientIPParser | ||||||
| 	templates               *template.Template | 	trustedIPs           *ip.NetSet | ||||||
| 	realClientIPParser      ipapi.RealClientIPParser | 	Banner               string | ||||||
| 	trustedIPs              *ip.NetSet | 	Footer               string | ||||||
| 	Banner                  string |  | ||||||
| 	Footer                  string |  | ||||||
| 
 | 
 | ||||||
| 	sessionChain alice.Chain | 	sessionChain alice.Chain | ||||||
| 	headersChain alice.Chain | 	headersChain alice.Chain | ||||||
|  | @ -202,25 +199,23 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | ||||||
| 		AuthOnlyPath:      fmt.Sprintf("%s/auth", opts.ProxyPrefix), | 		AuthOnlyPath:      fmt.Sprintf("%s/auth", opts.ProxyPrefix), | ||||||
| 		UserInfoPath:      fmt.Sprintf("%s/userinfo", opts.ProxyPrefix), | 		UserInfoPath:      fmt.Sprintf("%s/userinfo", opts.ProxyPrefix), | ||||||
| 
 | 
 | ||||||
| 		ProxyPrefix:             opts.ProxyPrefix, | 		ProxyPrefix:          opts.ProxyPrefix, | ||||||
| 		provider:                opts.GetProvider(), | 		provider:             opts.GetProvider(), | ||||||
| 		providerNameOverride:    opts.ProviderName, | 		providerNameOverride: opts.ProviderName, | ||||||
| 		sessionStore:            sessionStore, | 		sessionStore:         sessionStore, | ||||||
| 		serveMux:                upstreamProxy, | 		serveMux:             upstreamProxy, | ||||||
| 		redirectURL:             redirectURL, | 		redirectURL:          redirectURL, | ||||||
| 		allowedRoutes:           allowedRoutes, | 		allowedRoutes:        allowedRoutes, | ||||||
| 		whitelistDomains:        opts.WhitelistDomains, | 		whitelistDomains:     opts.WhitelistDomains, | ||||||
| 		skipAuthPreflight:       opts.SkipAuthPreflight, | 		skipAuthPreflight:    opts.SkipAuthPreflight, | ||||||
| 		skipJwtBearerTokens:     opts.SkipJwtBearerTokens, | 		skipJwtBearerTokens:  opts.SkipJwtBearerTokens, | ||||||
| 		mainJwtBearerVerifier:   opts.GetOIDCVerifier(), | 		realClientIPParser:   opts.GetRealClientIPParser(), | ||||||
| 		extraJwtBearerVerifiers: opts.GetJWTBearerVerifiers(), | 		SkipProviderButton:   opts.SkipProviderButton, | ||||||
| 		realClientIPParser:      opts.GetRealClientIPParser(), | 		templates:            templates, | ||||||
| 		SkipProviderButton:      opts.SkipProviderButton, | 		trustedIPs:           trustedIPs, | ||||||
| 		templates:               templates, | 		Banner:               opts.Banner, | ||||||
| 		trustedIPs:              trustedIPs, | 		Footer:               opts.Footer, | ||||||
| 		Banner:                  opts.Banner, | 		SignInMessage:        buildSignInMessage(opts), | ||||||
| 		Footer:                  opts.Footer, |  | ||||||
| 		SignInMessage:           buildSignInMessage(opts), |  | ||||||
| 
 | 
 | ||||||
| 		basicAuthValidator:  basicAuthValidator, | 		basicAuthValidator:  basicAuthValidator, | ||||||
| 		displayHtpasswdForm: basicAuthValidator != nil && opts.DisplayHtpasswdForm, | 		displayHtpasswdForm: basicAuthValidator != nil && opts.DisplayHtpasswdForm, | ||||||
|  | @ -266,22 +261,13 @@ func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionSt | ||||||
| 	chain := alice.New() | 	chain := alice.New() | ||||||
| 
 | 
 | ||||||
| 	if opts.SkipJwtBearerTokens { | 	if opts.SkipJwtBearerTokens { | ||||||
| 		sessionLoaders := []middlewareapi.TokenToSessionLoader{} | 		sessionLoaders := []middlewareapi.TokenToSessionFunc{ | ||||||
| 		if opts.GetOIDCVerifier() != nil { | 			opts.GetProvider().CreateSessionFromToken, | ||||||
| 			sessionLoaders = append(sessionLoaders, middlewareapi.TokenToSessionLoader{ |  | ||||||
| 				Verifier: func(ctx context.Context, token string) (interface{}, error) { |  | ||||||
| 					return opts.GetOIDCVerifier().Verify(ctx, token) |  | ||||||
| 				}, |  | ||||||
| 				TokenToSession: opts.GetProvider().CreateSessionFromToken, |  | ||||||
| 			}) |  | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		for _, verifier := range opts.GetJWTBearerVerifiers() { | 		for _, verifier := range opts.GetJWTBearerVerifiers() { | ||||||
| 			sessionLoaders = append(sessionLoaders, middlewareapi.TokenToSessionLoader{ | 			sessionLoaders = append(sessionLoaders, | ||||||
| 				Verifier: func(ctx context.Context, token string) (interface{}, error) { | 				middlewareapi.CreateTokenToSessionFunc(verifier.Verify)) | ||||||
| 					return verifier.Verify(ctx, token) |  | ||||||
| 				}, |  | ||||||
| 			}) |  | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		chain = chain.Append(middleware.NewJwtSessionLoader(sessionLoaders)) | 		chain = chain.Append(middleware.NewJwtSessionLoader(sessionLoaders)) | ||||||
|  |  | ||||||
|  | @ -2,25 +2,57 @@ package middleware | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
|  | 	"fmt" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/coreos/go-oidc" | ||||||
| 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // TokenToSessionFunc takes a rawIDToken and an idToken and converts it into a
 | // TokenToSessionFunc takes a raw ID Token and converts it into a SessionState.
 | ||||||
| // SessionState.
 | type TokenToSessionFunc func(ctx context.Context, token string) (*sessionsapi.SessionState, error) | ||||||
| type TokenToSessionFunc func(ctx context.Context, token string, verify VerifyFunc) (*sessionsapi.SessionState, error) |  | ||||||
| 
 | 
 | ||||||
| // VerifyFunc takes a raw bearer token and verifies it
 | // VerifyFunc takes a raw bearer token and verifies it returning the converted
 | ||||||
| type VerifyFunc func(ctx context.Context, token string) (interface{}, error) | // oidc.IDToken representation of the token.
 | ||||||
|  | type VerifyFunc func(ctx context.Context, token string) (*oidc.IDToken, error) | ||||||
| 
 | 
 | ||||||
| // TokenToSessionLoader pairs a token verifier with the correct converter function
 | // CreateTokenToSessionFunc provides a handler that is a default implementation
 | ||||||
| // to convert the ID Token to a SessionState.
 | // for converting a JWT into a session.
 | ||||||
| type TokenToSessionLoader struct { | func CreateTokenToSessionFunc(verify VerifyFunc) TokenToSessionFunc { | ||||||
| 	// Verifier is used to verify that the ID Token was signed by the claimed issuer
 | 	return func(ctx context.Context, token string) (*sessionsapi.SessionState, error) { | ||||||
| 	// and that the token has not been tampered with.
 | 		var claims struct { | ||||||
| 	Verifier VerifyFunc | 			Subject           string `json:"sub"` | ||||||
|  | 			Email             string `json:"email"` | ||||||
|  | 			Verified          *bool  `json:"email_verified"` | ||||||
|  | 			PreferredUsername string `json:"preferred_username"` | ||||||
|  | 		} | ||||||
| 
 | 
 | ||||||
| 	// TokenToSession converts a raw bearer token to a SessionState.
 | 		idToken, err := verify(ctx, token) | ||||||
| 	// (Optional) If not set a default basic implementation is used.
 | 		if err != nil { | ||||||
| 	TokenToSession TokenToSessionFunc | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if err := idToken.Claims(&claims); err != nil { | ||||||
|  | 			return nil, fmt.Errorf("failed to parse bearer token claims: %v", err) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if claims.Email == "" { | ||||||
|  | 			claims.Email = claims.Subject | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		if claims.Verified != nil && !*claims.Verified { | ||||||
|  | 			return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		newSession := &sessionsapi.SessionState{ | ||||||
|  | 			Email:             claims.Email, | ||||||
|  | 			User:              claims.Subject, | ||||||
|  | 			PreferredUsername: claims.PreferredUsername, | ||||||
|  | 			AccessToken:       token, | ||||||
|  | 			IDToken:           token, | ||||||
|  | 			RefreshToken:      "", | ||||||
|  | 			ExpiresOn:         &idToken.Expiry, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		return newSession, nil | ||||||
|  | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -1,12 +1,10 @@ | ||||||
| package middleware | package middleware | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" |  | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"regexp" | 	"regexp" | ||||||
| 
 | 
 | ||||||
| 	"github.com/coreos/go-oidc" |  | ||||||
| 	"github.com/justinas/alice" | 	"github.com/justinas/alice" | ||||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||||
| 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||||
|  | @ -16,16 +14,7 @@ import ( | ||||||
| 
 | 
 | ||||||
| const jwtRegexFormat = `^ey[IJ][a-zA-Z0-9_-]*\.ey[IJ][a-zA-Z0-9_-]*\.[a-zA-Z0-9_-]+$` | const jwtRegexFormat = `^ey[IJ][a-zA-Z0-9_-]*\.ey[IJ][a-zA-Z0-9_-]*\.[a-zA-Z0-9_-]+$` | ||||||
| 
 | 
 | ||||||
| func NewJwtSessionLoader(sessionLoaders []middlewareapi.TokenToSessionLoader) alice.Constructor { | func NewJwtSessionLoader(sessionLoaders []middlewareapi.TokenToSessionFunc) alice.Constructor { | ||||||
| 	for i, loader := range sessionLoaders { |  | ||||||
| 		if loader.TokenToSession == nil { |  | ||||||
| 			sessionLoaders[i] = middlewareapi.TokenToSessionLoader{ |  | ||||||
| 				Verifier:       loader.Verifier, |  | ||||||
| 				TokenToSession: createSessionFromToken, |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	js := &jwtSessionLoader{ | 	js := &jwtSessionLoader{ | ||||||
| 		jwtRegex:       regexp.MustCompile(jwtRegexFormat), | 		jwtRegex:       regexp.MustCompile(jwtRegexFormat), | ||||||
| 		sessionLoaders: sessionLoaders, | 		sessionLoaders: sessionLoaders, | ||||||
|  | @ -37,7 +26,7 @@ func NewJwtSessionLoader(sessionLoaders []middlewareapi.TokenToSessionLoader) al | ||||||
| // Authorization headers.
 | // Authorization headers.
 | ||||||
| type jwtSessionLoader struct { | type jwtSessionLoader struct { | ||||||
| 	jwtRegex       *regexp.Regexp | 	jwtRegex       *regexp.Regexp | ||||||
| 	sessionLoaders []middlewareapi.TokenToSessionLoader | 	sessionLoaders []middlewareapi.TokenToSessionFunc | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // loadSession attempts to load a session from a JWT stored in an Authorization
 | // loadSession attempts to load a session from a JWT stored in an Authorization
 | ||||||
|  | @ -83,7 +72,7 @@ func (j *jwtSessionLoader) getJwtSession(req *http.Request) (*sessionsapi.Sessio | ||||||
| 
 | 
 | ||||||
| 	errs := []error{fmt.Errorf("unable to verify jwt token: %q", req.Header.Get("Authorization"))} | 	errs := []error{fmt.Errorf("unable to verify jwt token: %q", req.Header.Get("Authorization"))} | ||||||
| 	for _, loader := range j.sessionLoaders { | 	for _, loader := range j.sessionLoaders { | ||||||
| 		session, err := loader.TokenToSession(req.Context(), token, loader.Verifier) | 		session, err := loader(req.Context(), token) | ||||||
| 		if err == nil { | 		if err == nil { | ||||||
| 			return session, nil | 			return session, nil | ||||||
| 		} else { | 		} else { | ||||||
|  | @ -135,48 +124,3 @@ func (j *jwtSessionLoader) getBasicToken(token string) (string, error) { | ||||||
| 
 | 
 | ||||||
| 	return "", fmt.Errorf("invalid basic auth token found in authorization header") | 	return "", fmt.Errorf("invalid basic auth token found in authorization header") | ||||||
| } | } | ||||||
| 
 |  | ||||||
| // createSessionFromToken is a default implementation for converting
 |  | ||||||
| // a JWT into a session state.
 |  | ||||||
| func createSessionFromToken(ctx context.Context, token string, verify middlewareapi.VerifyFunc) (*sessionsapi.SessionState, error) { |  | ||||||
| 	var claims struct { |  | ||||||
| 		Subject           string `json:"sub"` |  | ||||||
| 		Email             string `json:"email"` |  | ||||||
| 		Verified          *bool  `json:"email_verified"` |  | ||||||
| 		PreferredUsername string `json:"preferred_username"` |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	verifiedToken, err := verify(ctx, token) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	idToken, ok := verifiedToken.(*oidc.IDToken) |  | ||||||
| 	if !ok { |  | ||||||
| 		return nil, fmt.Errorf("failed to create IDToken from bearer token: %s", token) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if err := idToken.Claims(&claims); err != nil { |  | ||||||
| 		return nil, fmt.Errorf("failed to parse bearer token claims: %v", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if claims.Email == "" { |  | ||||||
| 		claims.Email = claims.Subject |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if claims.Verified != nil && !*claims.Verified { |  | ||||||
| 		return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	newSession := &sessionsapi.SessionState{ |  | ||||||
| 		Email:             claims.Email, |  | ||||||
| 		User:              claims.Subject, |  | ||||||
| 		PreferredUsername: claims.PreferredUsername, |  | ||||||
| 		AccessToken:       token, |  | ||||||
| 		IDToken:           token, |  | ||||||
| 		RefreshToken:      "", |  | ||||||
| 		ExpiresOn:         &idToken.Expiry, |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return newSession, nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  | @ -26,7 +26,7 @@ import ( | ||||||
| type noOpKeySet struct { | type noOpKeySet struct { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (noOpKeySet) VerifySignature(ctx context.Context, jwt string) (payload []byte, err error) { | func (noOpKeySet) VerifySignature(_ context.Context, jwt string) (payload []byte, err error) { | ||||||
| 	splitStrings := strings.Split(jwt, ".") | 	splitStrings := strings.Split(jwt, ".") | ||||||
| 	payloadString := splitStrings[1] | 	payloadString := splitStrings[1] | ||||||
| 	return base64.RawURLEncoding.DecodeString(payloadString) | 	return base64.RawURLEncoding.DecodeString(payloadString) | ||||||
|  | @ -78,16 +78,14 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` | ||||||
| 		const nonVerifiedToken = validToken | 		const nonVerifiedToken = validToken | ||||||
| 
 | 
 | ||||||
| 		BeforeEach(func() { | 		BeforeEach(func() { | ||||||
| 			verifier = func(ctx context.Context, token string) (interface{}, error) { | 			verifier = oidc.NewVerifier( | ||||||
| 				return oidc.NewVerifier( | 				"https://issuer.example.com", | ||||||
| 					"https://issuer.example.com", | 				noOpKeySet{}, | ||||||
| 					noOpKeySet{}, | 				&oidc.Config{ | ||||||
| 					&oidc.Config{ | 					ClientID:        "https://test.myapp.com", | ||||||
| 						ClientID:        "https://test.myapp.com", | 					SkipExpiryCheck: true, | ||||||
| 						SkipExpiryCheck: true, | 				}, | ||||||
| 					}, | 			).Verify | ||||||
| 				).Verify(ctx, token) |  | ||||||
| 			} |  | ||||||
| 		}) | 		}) | ||||||
| 
 | 
 | ||||||
| 		type jwtSessionLoaderTableInput struct { | 		type jwtSessionLoaderTableInput struct { | ||||||
|  | @ -110,10 +108,8 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` | ||||||
| 
 | 
 | ||||||
| 				rw := httptest.NewRecorder() | 				rw := httptest.NewRecorder() | ||||||
| 
 | 
 | ||||||
| 				sessionLoaders := []middlewareapi.TokenToSessionLoader{ | 				sessionLoaders := []middlewareapi.TokenToSessionFunc{ | ||||||
| 					{ | 					middlewareapi.CreateTokenToSessionFunc(verifier), | ||||||
| 						Verifier: verifier, |  | ||||||
| 					}, |  | ||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
| 				// Create the handler with a next handler that will capture the session
 | 				// Create the handler with a next handler that will capture the session
 | ||||||
|  | @ -175,24 +171,19 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` | ||||||
| 		const nonVerifiedToken = validToken | 		const nonVerifiedToken = validToken | ||||||
| 
 | 
 | ||||||
| 		BeforeEach(func() { | 		BeforeEach(func() { | ||||||
| 			verifier := func(ctx context.Context, token string) (interface{}, error) { | 			verifier := oidc.NewVerifier( | ||||||
| 				return oidc.NewVerifier( | 				"https://issuer.example.com", | ||||||
| 					"https://issuer.example.com", | 				noOpKeySet{}, | ||||||
| 					noOpKeySet{}, | 				&oidc.Config{ | ||||||
| 					&oidc.Config{ | 					ClientID:        "https://test.myapp.com", | ||||||
| 						ClientID:        "https://test.myapp.com", | 					SkipExpiryCheck: true, | ||||||
| 						SkipExpiryCheck: true, | 				}, | ||||||
| 					}, | 			).Verify | ||||||
| 				).Verify(ctx, token) |  | ||||||
| 			} |  | ||||||
| 
 | 
 | ||||||
| 			j = &jwtSessionLoader{ | 			j = &jwtSessionLoader{ | ||||||
| 				jwtRegex: regexp.MustCompile(jwtRegexFormat), | 				jwtRegex: regexp.MustCompile(jwtRegexFormat), | ||||||
| 				sessionLoaders: []middlewareapi.TokenToSessionLoader{ | 				sessionLoaders: []middlewareapi.TokenToSessionFunc{ | ||||||
| 					{ | 					middlewareapi.CreateTokenToSessionFunc(verifier), | ||||||
| 						Verifier:       verifier, |  | ||||||
| 						TokenToSession: createSessionFromToken, |  | ||||||
| 					}, |  | ||||||
| 				}, | 				}, | ||||||
| 			} | 			} | ||||||
| 		}) | 		}) | ||||||
|  | @ -402,7 +393,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` | ||||||
| 		) | 		) | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	Context("createSessionFromToken", func() { | 	Context("CreateTokenToSessionFunc", func() { | ||||||
| 		ctx := context.Background() | 		ctx := context.Background() | ||||||
| 		expiresFuture := time.Now().Add(time.Duration(5) * time.Minute) | 		expiresFuture := time.Now().Add(time.Duration(5) * time.Minute) | ||||||
| 		verified := true | 		verified := true | ||||||
|  | @ -414,7 +405,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` | ||||||
| 			jwt.StandardClaims | 			jwt.StandardClaims | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		type createSessionStateTableInput struct { | 		type tokenToSessionTableInput struct { | ||||||
| 			idToken         idTokenClaims | 			idToken         idTokenClaims | ||||||
| 			expectedErr     error | 			expectedErr     error | ||||||
| 			expectedUser    string | 			expectedUser    string | ||||||
|  | @ -423,8 +414,8 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		DescribeTable("when creating a session from an IDToken", | 		DescribeTable("when creating a session from an IDToken", | ||||||
| 			func(in createSessionStateTableInput) { | 			func(in tokenToSessionTableInput) { | ||||||
| 				verifier := func(ctx context.Context, token string) (interface{}, error) { | 				verifier := func(ctx context.Context, token string) (*oidc.IDToken, error) { | ||||||
| 					oidcVerifier := oidc.NewVerifier( | 					oidcVerifier := oidc.NewVerifier( | ||||||
| 						"https://issuer.example.com", | 						"https://issuer.example.com", | ||||||
| 						noOpKeySet{}, | 						noOpKeySet{}, | ||||||
|  | @ -443,7 +434,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` | ||||||
| 				rawIDToken, err := jwt.NewWithClaims(jwt.SigningMethodRS256, in.idToken).SignedString(key) | 				rawIDToken, err := jwt.NewWithClaims(jwt.SigningMethodRS256, in.idToken).SignedString(key) | ||||||
| 				Expect(err).ToNot(HaveOccurred()) | 				Expect(err).ToNot(HaveOccurred()) | ||||||
| 
 | 
 | ||||||
| 				session, err := createSessionFromToken(ctx, rawIDToken, verifier) | 				session, err := middlewareapi.CreateTokenToSessionFunc(verifier)(ctx, rawIDToken) | ||||||
| 				if in.expectedErr != nil { | 				if in.expectedErr != nil { | ||||||
| 					Expect(err).To(MatchError(in.expectedErr)) | 					Expect(err).To(MatchError(in.expectedErr)) | ||||||
| 					Expect(session).To(BeNil()) | 					Expect(session).To(BeNil()) | ||||||
|  | @ -459,7 +450,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` | ||||||
| 				Expect(session.RefreshToken).To(BeEmpty()) | 				Expect(session.RefreshToken).To(BeEmpty()) | ||||||
| 				Expect(session.PreferredUsername).To(BeEmpty()) | 				Expect(session.PreferredUsername).To(BeEmpty()) | ||||||
| 			}, | 			}, | ||||||
| 			Entry("with no email", createSessionStateTableInput{ | 			Entry("with no email", tokenToSessionTableInput{ | ||||||
| 				idToken: idTokenClaims{ | 				idToken: idTokenClaims{ | ||||||
| 					StandardClaims: jwt.StandardClaims{ | 					StandardClaims: jwt.StandardClaims{ | ||||||
| 						Audience:  "asdf1234", | 						Audience:  "asdf1234", | ||||||
|  | @ -476,7 +467,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` | ||||||
| 				expectedEmail:   "123456789", | 				expectedEmail:   "123456789", | ||||||
| 				expectedExpires: &expiresFuture, | 				expectedExpires: &expiresFuture, | ||||||
| 			}), | 			}), | ||||||
| 			Entry("with a verified email", createSessionStateTableInput{ | 			Entry("with a verified email", tokenToSessionTableInput{ | ||||||
| 				idToken: idTokenClaims{ | 				idToken: idTokenClaims{ | ||||||
| 					StandardClaims: jwt.StandardClaims{ | 					StandardClaims: jwt.StandardClaims{ | ||||||
| 						Audience:  "asdf1234", | 						Audience:  "asdf1234", | ||||||
|  | @ -495,7 +486,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` | ||||||
| 				expectedEmail:   "foo@example.com", | 				expectedEmail:   "foo@example.com", | ||||||
| 				expectedExpires: &expiresFuture, | 				expectedExpires: &expiresFuture, | ||||||
| 			}), | 			}), | ||||||
| 			Entry("with a non-verified email", createSessionStateTableInput{ | 			Entry("with a non-verified email", tokenToSessionTableInput{ | ||||||
| 				idToken: idTokenClaims{ | 				idToken: idTokenClaims{ | ||||||
| 					StandardClaims: jwt.StandardClaims{ | 					StandardClaims: jwt.StandardClaims{ | ||||||
| 						Audience:  "asdf1234", | 						Audience:  "asdf1234", | ||||||
|  |  | ||||||
|  | @ -233,6 +233,9 @@ func parseProviderInfo(o *options.Options, msgs []string) []string { | ||||||
| 	p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs) | 	p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs) | ||||||
| 	p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs) | 	p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs) | ||||||
| 
 | 
 | ||||||
|  | 	// Make the OIDC Verifier accessible to all providers that can support it
 | ||||||
|  | 	p.Verifier = o.GetOIDCVerifier() | ||||||
|  | 
 | ||||||
| 	p.SetAllowedGroups(o.AllowedGroups) | 	p.SetAllowedGroups(o.AllowedGroups) | ||||||
| 
 | 
 | ||||||
| 	provider := providers.New(o.ProviderType, p) | 	provider := providers.New(o.ProviderType, p) | ||||||
|  | @ -273,18 +276,14 @@ func parseProviderInfo(o *options.Options, msgs []string) []string { | ||||||
| 		p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail | 		p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail | ||||||
| 		p.UserIDClaim = o.UserIDClaim | 		p.UserIDClaim = o.UserIDClaim | ||||||
| 		p.GroupsClaim = o.OIDCGroupsClaim | 		p.GroupsClaim = o.OIDCGroupsClaim | ||||||
| 		if o.GetOIDCVerifier() == nil { | 		if p.Verifier == nil { | ||||||
| 			msgs = append(msgs, "oidc provider requires an oidc issuer URL") | 			msgs = append(msgs, "oidc provider requires an oidc issuer URL") | ||||||
| 		} else { |  | ||||||
| 			p.Verifier = o.GetOIDCVerifier() |  | ||||||
| 		} | 		} | ||||||
| 	case *providers.GitLabProvider: | 	case *providers.GitLabProvider: | ||||||
| 		p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail | 		p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail | ||||||
| 		p.Groups = o.GitLabGroup | 		p.Groups = o.GitLabGroup | ||||||
| 
 | 
 | ||||||
| 		if o.GetOIDCVerifier() != nil { | 		if p.Verifier == nil { | ||||||
| 			p.Verifier = o.GetOIDCVerifier() |  | ||||||
| 		} else { |  | ||||||
| 			// Initialize with default verifier for gitlab.com
 | 			// Initialize with default verifier for gitlab.com
 | ||||||
| 			ctx := context.Background() | 			ctx := context.Background() | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -11,7 +11,6 @@ import ( | ||||||
| 	oidc "github.com/coreos/go-oidc" | 	oidc "github.com/coreos/go-oidc" | ||||||
| 	"golang.org/x/oauth2" | 	"golang.org/x/oauth2" | ||||||
| 
 | 
 | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" |  | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" | ||||||
|  | @ -23,7 +22,6 @@ const emailClaim = "email" | ||||||
| type OIDCProvider struct { | type OIDCProvider struct { | ||||||
| 	*ProviderData | 	*ProviderData | ||||||
| 
 | 
 | ||||||
| 	Verifier             *oidc.IDTokenVerifier |  | ||||||
| 	AllowUnverifiedEmail bool | 	AllowUnverifiedEmail bool | ||||||
| 	UserIDClaim          string | 	UserIDClaim          string | ||||||
| 	GroupsClaim          string | 	GroupsClaim          string | ||||||
|  | @ -176,17 +174,12 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok | ||||||
| 	return newSession, nil | 	return newSession, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string, verify middleware.VerifyFunc) (*sessions.SessionState, error) { | func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) { | ||||||
| 	verifiedToken, err := verify(ctx, token) | 	idToken, err := p.Verifier.Verify(ctx, token) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	idToken, ok := verifiedToken.(*oidc.IDToken) |  | ||||||
| 	if !ok { |  | ||||||
| 		return nil, fmt.Errorf("failed to create IDToken from bearer token: %s", token) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	newSession, err := p.createSessionStateInternal(ctx, idToken, nil) | 	newSession, err := p.createSessionStateInternal(ctx, idToken, nil) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
|  |  | ||||||
|  | @ -144,16 +144,17 @@ func newOIDCProvider(serverURL *url.URL) *OIDCProvider { | ||||||
| 			Scheme: serverURL.Scheme, | 			Scheme: serverURL.Scheme, | ||||||
| 			Host:   serverURL.Host, | 			Host:   serverURL.Host, | ||||||
| 			Path:   "/api"}, | 			Path:   "/api"}, | ||||||
| 		Scope: "openid profile offline_access"} | 		Scope: "openid profile offline_access", | ||||||
| 
 |  | ||||||
| 	p := &OIDCProvider{ |  | ||||||
| 		ProviderData: providerData, |  | ||||||
| 		Verifier: oidc.NewVerifier( | 		Verifier: oidc.NewVerifier( | ||||||
| 			"https://issuer.example.com", | 			"https://issuer.example.com", | ||||||
| 			fakeKeySetStub{}, | 			fakeKeySetStub{}, | ||||||
| 			&oidc.Config{ClientID: clientID}, | 			&oidc.Config{ClientID: clientID}, | ||||||
| 		), | 		), | ||||||
| 		UserIDClaim: "email", | 	} | ||||||
|  | 
 | ||||||
|  | 	p := &OIDCProvider{ | ||||||
|  | 		ProviderData: providerData, | ||||||
|  | 		UserIDClaim:  "email", | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return p | 	return p | ||||||
|  | @ -347,18 +348,7 @@ func TestCreateSessionStateFromBearerToken(t *testing.T) { | ||||||
| 			rawIDToken, err := newSignedTestIDToken(tc.IDToken) | 			rawIDToken, err := newSignedTestIDToken(tc.IDToken) | ||||||
| 			assert.NoError(t, err) | 			assert.NoError(t, err) | ||||||
| 
 | 
 | ||||||
| 			verifyFunc := func(ctx context.Context, token string) (interface{}, error) { | 			ss, err := provider.CreateSessionFromToken(context.Background(), rawIDToken) | ||||||
| 				keyset := fakeKeySetStub{} |  | ||||||
| 				verifier := oidc.NewVerifier("https://issuer.example.com", keyset, |  | ||||||
| 					&oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true}) |  | ||||||
| 
 |  | ||||||
| 				idToken, err := verifier.Verify(ctx, token) |  | ||||||
| 				assert.NoError(t, err) |  | ||||||
| 
 |  | ||||||
| 				return idToken, nil |  | ||||||
| 			} |  | ||||||
| 
 |  | ||||||
| 			ss, err := provider.CreateSessionFromToken(context.Background(), rawIDToken, verifyFunc) |  | ||||||
| 			assert.NoError(t, err) | 			assert.NoError(t, err) | ||||||
| 
 | 
 | ||||||
| 			assert.Equal(t, tc.ExpectedUser, ss.User) | 			assert.Equal(t, tc.ExpectedUser, ss.User) | ||||||
|  |  | ||||||
|  | @ -5,6 +5,7 @@ import ( | ||||||
| 	"io/ioutil" | 	"io/ioutil" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/coreos/go-oidc" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -26,6 +27,7 @@ type ProviderData struct { | ||||||
| 	ClientSecretFile string | 	ClientSecretFile string | ||||||
| 	Scope            string | 	Scope            string | ||||||
| 	Prompt           string | 	Prompt           string | ||||||
|  | 	Verifier         *oidc.IDTokenVerifier | ||||||
| 
 | 
 | ||||||
| 	// Universal Group authorization data structure
 | 	// Universal Group authorization data structure
 | ||||||
| 	// any provider can set to consume
 | 	// any provider can set to consume
 | ||||||
|  |  | ||||||
|  | @ -126,6 +126,9 @@ func (p *ProviderData) RefreshSessionIfNeeded(_ context.Context, _ *sessions.Ses | ||||||
| 
 | 
 | ||||||
| // CreateSessionStateFromBearerToken should be implemented to allow providers
 | // CreateSessionStateFromBearerToken should be implemented to allow providers
 | ||||||
| // to convert ID tokens into sessions
 | // to convert ID tokens into sessions
 | ||||||
| func (p *ProviderData) CreateSessionFromToken(_ context.Context, _ string, _ middleware.VerifyFunc) (*sessions.SessionState, error) { | func (p *ProviderData) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) { | ||||||
|  | 	if p.Verifier != nil { | ||||||
|  | 		return middleware.CreateTokenToSessionFunc(p.Verifier.Verify)(ctx, token) | ||||||
|  | 	} | ||||||
| 	return nil, ErrNotImplemented | 	return nil, ErrNotImplemented | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -3,7 +3,6 @@ package providers | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 
 | 
 | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" |  | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -18,7 +17,7 @@ type Provider interface { | ||||||
| 	ValidateSession(ctx context.Context, s *sessions.SessionState) bool | 	ValidateSession(ctx context.Context, s *sessions.SessionState) bool | ||||||
| 	GetLoginURL(redirectURI, finalRedirect string) string | 	GetLoginURL(redirectURI, finalRedirect string) string | ||||||
| 	RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) | 	RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) | ||||||
| 	CreateSessionFromToken(ctx context.Context, token string, verify middleware.VerifyFunc) (*sessions.SessionState, error) | 	CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // New provides a new Provider based on the configured provider string
 | // New provides a new Provider based on the configured provider string
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue