diff --git a/oauthproxy.go b/oauthproxy.go index 02891612..d546f005 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -269,14 +269,18 @@ func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionSt sessionLoaders := []middlewareapi.TokenToSessionLoader{} if opts.GetOIDCVerifier() != nil { sessionLoaders = append(sessionLoaders, middlewareapi.TokenToSessionLoader{ - Verifier: opts.GetOIDCVerifier(), - TokenToSession: opts.GetProvider().CreateSessionFromBearer, + Verifier: func(ctx context.Context, token string) (interface{}, error) { + return opts.GetOIDCVerifier().Verify(ctx, token) + }, + TokenToSession: opts.GetProvider().CreateSessionFromToken, }) } for _, verifier := range opts.GetJWTBearerVerifiers() { sessionLoaders = append(sessionLoaders, middlewareapi.TokenToSessionLoader{ - Verifier: verifier, + Verifier: func(ctx context.Context, token string) (interface{}, error) { + return verifier.Verify(ctx, token) + }, }) } diff --git a/pkg/apis/middleware/session.go b/pkg/apis/middleware/session.go index 95a76fba..a8a3bbea 100644 --- a/pkg/apis/middleware/session.go +++ b/pkg/apis/middleware/session.go @@ -3,22 +3,24 @@ package middleware import ( "context" - "github.com/coreos/go-oidc" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" ) // TokenToSessionFunc takes a rawIDToken and an idToken and converts it into a // SessionState. -type TokenToSessionFunc func(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessionsapi.SessionState, error) +type TokenToSessionFunc func(ctx context.Context, token string, verify VerifyFunc) (*sessionsapi.SessionState, error) + +// VerifyFunc takes a raw bearer token and verifies it +type VerifyFunc func(ctx context.Context, token string) (interface{}, error) // TokenToSessionLoader pairs a token verifier with the correct converter function // to convert the ID Token to a SessionState. type TokenToSessionLoader struct { - // Verfier is used to verify that the ID Token was signed by the claimed issuer + // Verifier is used to verify that the ID Token was signed by the claimed issuer // and that the token has not been tampered with. - Verifier *oidc.IDTokenVerifier + Verifier VerifyFunc - // TokenToSession converts a rawIDToken and an idToken to a SessionState. + // TokenToSession converts a raw bearer token to a SessionState. // (Optional) If not set a default basic implementation is used. TokenToSession TokenToSessionFunc } diff --git a/pkg/middleware/jwt_session.go b/pkg/middleware/jwt_session.go index 024a45ac..5e99e0df 100644 --- a/pkg/middleware/jwt_session.go +++ b/pkg/middleware/jwt_session.go @@ -13,14 +13,14 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" ) -const jwtRegexFormat = `^eyJ[a-zA-Z0-9_-]*\.eyJ[a-zA-Z0-9_-]*\.[a-zA-Z0-9_-]+$` +const jwtRegexFormat = `^ey[IJ][a-zA-Z0-9_-]*\.ey[IJ][a-zA-Z0-9_-]*\.[a-zA-Z0-9_-]+$` func NewJwtSessionLoader(sessionLoaders []middlewareapi.TokenToSessionLoader) alice.Constructor { for i, loader := range sessionLoaders { if loader.TokenToSession == nil { sessionLoaders[i] = middlewareapi.TokenToSessionLoader{ Verifier: loader.Verifier, - TokenToSession: createSessionStateFromBearerToken, + TokenToSession: createSessionFromToken, } } } @@ -75,24 +75,24 @@ func (j *jwtSessionLoader) getJwtSession(req *http.Request) (*sessionsapi.Sessio return nil, nil } - rawBearerToken, err := j.findBearerTokenFromHeader(auth) + token, err := j.findTokenFromHeader(auth) if err != nil { return nil, err } for _, loader := range j.sessionLoaders { - bearerToken, err := loader.Verifier.Verify(req.Context(), rawBearerToken) + session, err := loader.TokenToSession(req.Context(), token, loader.Verifier) if err == nil { - // The token was verified, convert it to a session - return loader.TokenToSession(req.Context(), rawBearerToken, bearerToken) + return session, nil } } + // TODO (@NickMeves) Aggregate error logs in the chain return nil, fmt.Errorf("unable to verify jwt token: %q", req.Header.Get("Authorization")) } -// findBearerTokenFromHeader finds a valid JWT token from the Authorization header of a given request. -func (j *jwtSessionLoader) findBearerTokenFromHeader(header string) (string, error) { +// findTokenFromHeader finds a valid JWT token from the Authorization header of a given request. +func (j *jwtSessionLoader) findTokenFromHeader(header string) (string, error) { tokenType, token, err := splitAuthHeader(header) if err != nil { return "", err @@ -133,9 +133,9 @@ func (j *jwtSessionLoader) getBasicToken(token string) (string, error) { return "", fmt.Errorf("invalid basic auth token found in authorization header") } -// createSessionStateFromBearerToken is a default implementation for converting +// createSessionFromToken is a default implementation for converting // a JWT into a session state. -func createSessionStateFromBearerToken(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessionsapi.SessionState, error) { +func createSessionFromToken(ctx context.Context, token string, verify middlewareapi.VerifyFunc) (*sessionsapi.SessionState, error) { var claims struct { Subject string `json:"sub"` Email string `json:"email"` @@ -143,6 +143,16 @@ func createSessionStateFromBearerToken(ctx context.Context, rawIDToken string, i PreferredUsername string `json:"preferred_username"` } + verifiedToken, err := verify(ctx, token) + if err != nil { + return nil, err + } + + idToken, ok := verifiedToken.(*oidc.IDToken) + if !ok { + return nil, fmt.Errorf("failed to create IDToken from bearer token: %s", token) + } + if err := idToken.Claims(&claims); err != nil { return nil, fmt.Errorf("failed to parse bearer token claims: %v", err) } @@ -159,8 +169,8 @@ func createSessionStateFromBearerToken(ctx context.Context, rawIDToken string, i Email: claims.Email, User: claims.Subject, PreferredUsername: claims.PreferredUsername, - AccessToken: rawIDToken, - IDToken: rawIDToken, + AccessToken: token, + IDToken: token, RefreshToken: "", ExpiresOn: &idToken.Expiry, } diff --git a/pkg/middleware/jwt_session_test.go b/pkg/middleware/jwt_session_test.go index b9503731..794c8488 100644 --- a/pkg/middleware/jwt_session_test.go +++ b/pkg/middleware/jwt_session_test.go @@ -73,13 +73,20 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` const validToken = "eyJfoobar.eyJfoobar.12345asdf" Context("JwtSessionLoader", func() { - var verifier *oidc.IDTokenVerifier + var verifier middlewareapi.VerifyFunc const nonVerifiedToken = validToken BeforeEach(func() { - keyset := noOpKeySet{} - verifier = oidc.NewVerifier("https://issuer.example.com", keyset, - &oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true}) + verifier = func(ctx context.Context, token string) (interface{}, error) { + return oidc.NewVerifier( + "https://issuer.example.com", + noOpKeySet{}, + &oidc.Config{ + ClientID: "https://test.myapp.com", + SkipExpiryCheck: true, + }, + ).Verify(ctx, token) + } }) type jwtSessionLoaderTableInput struct { @@ -167,16 +174,23 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` const nonVerifiedToken = validToken BeforeEach(func() { - keyset := noOpKeySet{} - verifier := oidc.NewVerifier("https://issuer.example.com", keyset, - &oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true}) + verifier := func(ctx context.Context, token string) (interface{}, error) { + return oidc.NewVerifier( + "https://issuer.example.com", + noOpKeySet{}, + &oidc.Config{ + ClientID: "https://test.myapp.com", + SkipExpiryCheck: true, + }, + ).Verify(ctx, token) + } j = &jwtSessionLoader{ jwtRegex: regexp.MustCompile(jwtRegexFormat), sessionLoaders: []middlewareapi.TokenToSessionLoader{ { Verifier: verifier, - TokenToSession: createSessionStateFromBearerToken, + TokenToSession: createSessionFromToken, }, }, } @@ -239,7 +253,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` ) }) - Context("findBearerTokenFromHeader", func() { + Context("findTokenFromHeader", func() { var j *jwtSessionLoader BeforeEach(func() { @@ -256,7 +270,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` DescribeTable("with a header", func(in findBearerTokenFromHeaderTableInput) { - token, err := j.findBearerTokenFromHeader(in.header) + token, err := j.findTokenFromHeader(in.header) if in.expectedErr != nil { Expect(err).To(MatchError(in.expectedErr)) } else { @@ -381,7 +395,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` ) }) - Context("createSessionStateFromBearerToken", func() { + Context("createSessionFromToken", func() { ctx := context.Background() expiresFuture := time.Now().Add(time.Duration(5) * time.Minute) verified := true @@ -403,11 +417,18 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` DescribeTable("when creating a session from an IDToken", func(in createSessionStateTableInput) { - verifier := oidc.NewVerifier( - "https://issuer.example.com", - noOpKeySet{}, - &oidc.Config{ClientID: "asdf1234"}, - ) + verifier := func(ctx context.Context, token string) (interface{}, error) { + oidcVerifier := oidc.NewVerifier( + "https://issuer.example.com", + noOpKeySet{}, + &oidc.Config{ClientID: "asdf1234"}, + ) + + idToken, err := oidcVerifier.Verify(ctx, token) + Expect(err).ToNot(HaveOccurred()) + + return idToken, nil + } key, err := rsa.GenerateKey(rand.Reader, 2048) Expect(err).ToNot(HaveOccurred()) @@ -415,11 +436,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` rawIDToken, err := jwt.NewWithClaims(jwt.SigningMethodRS256, in.idToken).SignedString(key) Expect(err).ToNot(HaveOccurred()) - // Pass to a dummy Verifier to get an oidc.IDToken from the rawIDToken for our actual test below - idToken, err := verifier.Verify(context.Background(), rawIDToken) - Expect(err).ToNot(HaveOccurred()) - - session, err := createSessionStateFromBearerToken(ctx, rawIDToken, idToken) + session, err := createSessionFromToken(ctx, rawIDToken, verifier) if in.expectedErr != nil { Expect(err).To(MatchError(in.expectedErr)) Expect(session).To(BeNil()) diff --git a/providers/oidc.go b/providers/oidc.go index 7c48c42a..d94f27ce 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -11,6 +11,7 @@ import ( oidc "github.com/coreos/go-oidc" "golang.org/x/oauth2" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" @@ -175,14 +176,24 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok return newSession, nil } -func (p *OIDCProvider) CreateSessionFromBearer(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessions.SessionState, error) { +func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string, verify middleware.VerifyFunc) (*sessions.SessionState, error) { + verifiedToken, err := verify(ctx, token) + if err != nil { + return nil, err + } + + idToken, ok := verifiedToken.(*oidc.IDToken) + if !ok { + return nil, fmt.Errorf("failed to create IDToken from bearer token: %s", token) + } + newSession, err := p.createSessionStateInternal(ctx, idToken, nil) if err != nil { return nil, err } - newSession.AccessToken = rawIDToken - newSession.IDToken = rawIDToken + newSession.AccessToken = token + newSession.IDToken = token newSession.RefreshToken = "" newSession.ExpiresOn = &idToken.Expiry diff --git a/providers/oidc_test.go b/providers/oidc_test.go index cc4cdc8a..429a7fca 100644 --- a/providers/oidc_test.go +++ b/providers/oidc_test.go @@ -347,14 +347,18 @@ func TestCreateSessionStateFromBearerToken(t *testing.T) { rawIDToken, err := newSignedTestIDToken(tc.IDToken) assert.NoError(t, err) - keyset := fakeKeySetStub{} - verifier := oidc.NewVerifier("https://issuer.example.com", keyset, - &oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true}) + verifyFunc := func(ctx context.Context, token string) (interface{}, error) { + keyset := fakeKeySetStub{} + verifier := oidc.NewVerifier("https://issuer.example.com", keyset, + &oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true}) - idToken, err := verifier.Verify(context.Background(), rawIDToken) - assert.NoError(t, err) + idToken, err := verifier.Verify(ctx, token) + assert.NoError(t, err) - ss, err := provider.CreateSessionFromBearer(context.Background(), rawIDToken, idToken) + return idToken, nil + } + + ss, err := provider.CreateSessionFromToken(context.Background(), rawIDToken, verifyFunc) assert.NoError(t, err) assert.Equal(t, tc.ExpectedUser, ss.User) diff --git a/providers/provider_default.go b/providers/provider_default.go index 7a8c4e40..ee2e2824 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -8,8 +8,7 @@ import ( "net/url" "time" - "github.com/coreos/go-oidc" - + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests" ) @@ -127,6 +126,6 @@ func (p *ProviderData) RefreshSessionIfNeeded(_ context.Context, _ *sessions.Ses // CreateSessionStateFromBearerToken should be implemented to allow providers // to convert ID tokens into sessions -func (p *ProviderData) CreateSessionFromBearer(_ context.Context, _ string, _ *oidc.IDToken) (*sessions.SessionState, error) { +func (p *ProviderData) CreateSessionFromToken(_ context.Context, _ string, _ middleware.VerifyFunc) (*sessions.SessionState, error) { return nil, ErrNotImplemented } diff --git a/providers/providers.go b/providers/providers.go index 09abf725..8890a5a7 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -3,7 +3,7 @@ package providers import ( "context" - "github.com/coreos/go-oidc" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" ) @@ -18,7 +18,7 @@ type Provider interface { ValidateSession(ctx context.Context, s *sessions.SessionState) bool GetLoginURL(redirectURI, finalRedirect string) string RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) - CreateSessionFromBearer(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessions.SessionState, error) + CreateSessionFromToken(ctx context.Context, token string, verify middleware.VerifyFunc) (*sessions.SessionState, error) } // New provides a new Provider based on the configured provider string