From c81a7ed1976a7e79c1a70ebb10091fe503bb05b6 Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Sat, 4 Jul 2020 19:19:36 +0100 Subject: [PATCH] Add JWT session loader middleware --- pkg/apis/middleware/session.go | 24 ++ pkg/middleware/jwt_session.go | 186 ++++++++++ pkg/middleware/jwt_session_test.go | 532 +++++++++++++++++++++++++++++ 3 files changed, 742 insertions(+) create mode 100644 pkg/apis/middleware/session.go create mode 100644 pkg/middleware/jwt_session.go create mode 100644 pkg/middleware/jwt_session_test.go diff --git a/pkg/apis/middleware/session.go b/pkg/apis/middleware/session.go new file mode 100644 index 00000000..344ba31e --- /dev/null +++ b/pkg/apis/middleware/session.go @@ -0,0 +1,24 @@ +package middleware + +import ( + "context" + + "github.com/coreos/go-oidc" + sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" +) + +// TokenToSessionFunc takes a rawIDToken and an idToken and converts it into a +// SessionState. +type TokenToSessionFunc func(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessionsapi.SessionState, error) + +// TokenToSessionLoader pairs a token verifier with the correct converter function +// to convert the ID Token to a SessionState. +type TokenToSessionLoader struct { + // Verfier is used to verify that the ID Token was signed by the claimed issuer + // and that the token has not been tampered with. + Verifier *oidc.IDTokenVerifier + + // TokenToSession converts a rawIDToken and an idToken to a SessionState. + // (Optional) If not set a default basic implementation is used. + TokenToSession TokenToSessionFunc +} diff --git a/pkg/middleware/jwt_session.go b/pkg/middleware/jwt_session.go new file mode 100644 index 00000000..4c3087c7 --- /dev/null +++ b/pkg/middleware/jwt_session.go @@ -0,0 +1,186 @@ +package middleware + +import ( + "context" + "encoding/base64" + "fmt" + "net/http" + "regexp" + "strings" + + "github.com/coreos/go-oidc" + "github.com/justinas/alice" + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/middleware" + sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" + "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" +) + +const jwtRegexFormat = `^eyJ[a-zA-Z0-9_-]*\.eyJ[a-zA-Z0-9_-]*\.[a-zA-Z0-9_-]+$` + +func NewJwtSessionLoader(sessionLoaders []middlewareapi.TokenToSessionLoader) alice.Constructor { + for i, loader := range sessionLoaders { + if loader.TokenToSession == nil { + sessionLoaders[i] = middlewareapi.TokenToSessionLoader{ + Verifier: loader.Verifier, + TokenToSession: createSessionStateFromBearerToken, + } + } + } + + js := &jwtSessionLoader{ + jwtRegex: regexp.MustCompile(jwtRegexFormat), + sessionLoaders: sessionLoaders, + } + return js.loadSession +} + +// jwtSessionLoader is responsible for loading sessions from JWTs in +// Authorization headers. +type jwtSessionLoader struct { + jwtRegex *regexp.Regexp + sessionLoaders []middlewareapi.TokenToSessionLoader +} + +// loadSession attempts to load a session from a JWT stored in an Authorization +// header within the request. +// If no authorization header is found, or the header is invalid, no session +// will be loaded and the request will be passed to the next handler. +// If a session was loaded by a previous handler, it will not be replaced. +func (j *jwtSessionLoader) loadSession(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + scope := GetRequestScope(req) + // If scope is nil, this will panic. + // A scope should always be injected before this handler is called. + if scope.Session != nil { + // The session was already loaded, pass to the next handler + next.ServeHTTP(rw, req) + return + } + + session, err := j.getJwtSession(req) + if err != nil { + logger.Printf("Error retrieving session from token in Authorization header: %v", err) + } + + // Add the session to the scope if it was found + scope.Session = session + next.ServeHTTP(rw, req) + }) +} + +// getJwtSession loads a session based on a JWT token in the authorization header. +// (see the config options skip-jwt-bearer-tokens and extra-jwt-issuers) +func (j *jwtSessionLoader) getJwtSession(req *http.Request) (*sessionsapi.SessionState, error) { + auth := req.Header.Get("Authorization") + if auth == "" { + // No auth header provided, so don't attempt to load a session + return nil, nil + } + + rawBearerToken, err := j.findBearerTokenFromHeader(auth) + if err != nil { + return nil, err + } + + for _, loader := range j.sessionLoaders { + bearerToken, err := loader.Verifier.Verify(req.Context(), rawBearerToken) + if err == nil { + // The token was verified, convert it to a session + return loader.TokenToSession(req.Context(), rawBearerToken, bearerToken) + } + } + + return nil, fmt.Errorf("unable to verify jwt token: %q", req.Header.Get("Authorization")) +} + +// findBearerTokenFromHeader finds a valid JWT token from the Authorization header of a given request. +func (j *jwtSessionLoader) findBearerTokenFromHeader(header string) (string, error) { + tokenType, token, err := splitAuthHeader(header) + if err != nil { + return "", err + } + + if tokenType == "Bearer" && j.jwtRegex.MatchString(token) { + // Found a JWT as a bearer token + return token, nil + } + + if tokenType == "Basic" { + // Check if we have a Bearer token masquerading in Basic + return j.getBasicToken(token) + } + + return "", fmt.Errorf("no valid bearer token found in authorization header") +} + +// getBasicToken tries to extract a token from the basic value provided. +func (j *jwtSessionLoader) getBasicToken(token string) (string, error) { + b, err := base64.StdEncoding.DecodeString(token) + if err != nil { + return "", fmt.Errorf("invalid basic auth token: %v", err) + } + + pair := strings.SplitN(string(b), ":", 2) + if len(pair) != 2 { + return "", fmt.Errorf("invalid format: %q", b) + } + user, password := pair[0], pair[1] + + // check user, user+password, or just password for a token + if j.jwtRegex.MatchString(user) { + // Support blank passwords or magic `x-oauth-basic` passwords - nothing else + if password == "" || password == "x-oauth-basic" { + return user, nil + } + } else if j.jwtRegex.MatchString(password) { + // support passwords and ignore user + return password, nil + } + + return "", fmt.Errorf("invalid basic auth token found in authorization header") +} + +// splitAuthHeader takes the auth header value and splits it into the token type +// and the token value. +func splitAuthHeader(header string) (string, string, error) { + s := strings.Split(header, " ") + if len(s) != 2 { + return "", "", fmt.Errorf("invalid authorization header: %q", header) + } + return s[0], s[1], nil +} + +// createSessionStateFromBearerToken is a default implementation for converting +// a JWT into a session state. +func createSessionStateFromBearerToken(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessionsapi.SessionState, error) { + var claims struct { + Subject string `json:"sub"` + Email string `json:"email"` + Verified *bool `json:"email_verified"` + PreferredUsername string `json:"preferred_username"` + } + + if err := idToken.Claims(&claims); err != nil { + return nil, fmt.Errorf("failed to parse bearer token claims: %v", err) + } + + if claims.Email == "" { + claims.Email = claims.Subject + } + + if claims.Verified != nil && !*claims.Verified { + return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) + } + + newSession := &sessionsapi.SessionState{ + Email: claims.Email, + User: claims.Subject, + PreferredUsername: claims.PreferredUsername, + AccessToken: rawIDToken, + IDToken: rawIDToken, + RefreshToken: "", + ExpiresOn: &idToken.Expiry, + } + + return newSession, nil +} diff --git a/pkg/middleware/jwt_session_test.go b/pkg/middleware/jwt_session_test.go new file mode 100644 index 00000000..f8adf2d3 --- /dev/null +++ b/pkg/middleware/jwt_session_test.go @@ -0,0 +1,532 @@ +package middleware + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "regexp" + "strings" + "time" + + "github.com/coreos/go-oidc" + "github.com/dgrijalva/jwt-go" + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/middleware" + sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +type noOpKeySet struct { +} + +func (noOpKeySet) VerifySignature(ctx context.Context, jwt string) (payload []byte, err error) { + splitStrings := strings.Split(jwt, ".") + payloadString := splitStrings[1] + return base64.RawURLEncoding.DecodeString(payloadString) +} + +var _ = Describe("JWT Session Suite", func() { + /* token payload: + { + "sub": "1234567890", + "aud": "https://test.myapp.com", + "name": "John Doe", + "email": "john@example.com", + "iss": "https://issuer.example.com", + "iat": 1553691215, + "exp": 1912151821 + } + */ + const verifiedToken = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9." + + "eyJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjoiaHR0cHM6Ly90ZXN0Lm15YXBwLmNvbSIsIm5hbWUiOiJKb2huIERvZSIsImVtY" + + "WlsIjoiam9obkBleGFtcGxlLmNvbSIsImlzcyI6Imh0dHBzOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwiaWF0IjoxNTUzNjkxMj" + + "E1LCJleHAiOjE5MTIxNTE4MjF9." + + "rLVyzOnEldUq_pNkfa-WiV8TVJYWyZCaM2Am_uo8FGg11zD7l-qmz3x1seTvqpH6Y0Ty00fmv6dJnGnC8WMnPXQiodRTfhBSe" + + "OKZMu0HkMD2sg52zlKkbfLTO6ic5VnbVgwjjrB8am_Ta6w7kyFUaB5C1BsIrrLMldkWEhynbb8" + + const verifiedTokenXOAuthBasicBase64 = `ZXlKaGJHY2lPaUpTVXpJMU5pSXNJblI1Y0NJNklrcFhWQ0o5LmV5SnpkV0lpT2lJeE1qTTBOVFkz +T0Rrd0lpd2lZWFZrSWpvaWFIUjBjSE02THk5MFpYTjBMbTE1WVhCd0xtTnZiU0lzSW01aGJXVWlP +aUpLYjJodUlFUnZaU0lzSW1WdFlXbHNJam9pYW05b2JrQmxlR0Z0Y0d4bExtTnZiU0lzSW1semN5 +STZJbWgwZEhCek9pOHZhWE56ZFdWeUxtVjRZVzF3YkdVdVkyOXRJaXdpYVdGMElqb3hOVFV6Tmpr +eE1qRTFMQ0psZUhBaU9qRTVNVEl4TlRFNE1qRjkuckxWeXpPbkVsZFVxX3BOa2ZhLVdpVjhUVkpZ +V3laQ2FNMkFtX3VvOEZHZzExekQ3bC1xbXozeDFzZVR2cXBINlkwVHkwMGZtdjZkSm5HbkM4V01u +UFhRaW9kUlRmaEJTZU9LWk11MEhrTUQyc2c1MnpsS2tiZkxUTzZpYzVWbmJWZ3dqanJCOGFtX1Rh +Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` + + var verifiedSessionExpiry = time.Unix(1912151821, 0) + var verifiedSession = &sessionsapi.SessionState{ + AccessToken: verifiedToken, + IDToken: verifiedToken, + Email: "john@example.com", + User: "1234567890", + ExpiresOn: &verifiedSessionExpiry, + } + + // validToken will pass the token regex so can be used to check token fetching + // is valid. It will not pass the OIDC Verifier however. + const validToken = "eyJfoobar.eyJfoobar.12345asdf" + + Context("JwtSessionLoader", func() { + var verifier *oidc.IDTokenVerifier + const nonVerifiedToken = validToken + + BeforeEach(func() { + keyset := noOpKeySet{} + verifier = oidc.NewVerifier("https://issuer.example.com", keyset, + &oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true}) + }) + + type jwtSessionLoaderTableInput struct { + authorizationHeader string + existingSession *sessionsapi.SessionState + expectedSession *sessionsapi.SessionState + } + + DescribeTable("with an authorization header", + func(in jwtSessionLoaderTableInput) { + scope := &middlewareapi.RequestScope{ + Session: in.existingSession, + } + + // Set up the request with the authorization header and a request scope + req := httptest.NewRequest("", "/", nil) + req.Header.Set("Authorization", in.authorizationHeader) + contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) + req = req.WithContext(contextWithScope) + + rw := httptest.NewRecorder() + + sessionLoaders := []middlewareapi.TokenToSessionLoader{ + { + Verifier: verifier, + }, + } + + // Create the handler with a next handler that will capture the session + // from the scope + var gotSession *sessionsapi.SessionState + handler := NewJwtSessionLoader(sessionLoaders)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session + })) + handler.ServeHTTP(rw, req) + + Expect(gotSession).To(Equal(in.expectedSession)) + }, + Entry("", jwtSessionLoaderTableInput{ + authorizationHeader: "", + existingSession: nil, + expectedSession: nil, + }), + Entry("abcdef", jwtSessionLoaderTableInput{ + authorizationHeader: "abcdef", + existingSession: nil, + expectedSession: nil, + }), + Entry("abcdef (with existing session)", jwtSessionLoaderTableInput{ + authorizationHeader: "abcdef", + existingSession: &sessionsapi.SessionState{User: "user"}, + expectedSession: &sessionsapi.SessionState{User: "user"}, + }), + Entry("Bearer ", jwtSessionLoaderTableInput{ + authorizationHeader: fmt.Sprintf("Bearer %s", verifiedToken), + existingSession: nil, + expectedSession: verifiedSession, + }), + Entry("Bearer ", jwtSessionLoaderTableInput{ + authorizationHeader: fmt.Sprintf("Bearer %s", nonVerifiedToken), + existingSession: nil, + expectedSession: nil, + }), + Entry("Bearer (with existing session)", jwtSessionLoaderTableInput{ + authorizationHeader: fmt.Sprintf("Bearer %s", verifiedToken), + existingSession: &sessionsapi.SessionState{User: "user"}, + expectedSession: &sessionsapi.SessionState{User: "user"}, + }), + Entry("Basic Base64(:) (No password)", jwtSessionLoaderTableInput{ + authorizationHeader: "Basic ZXlKZm9vYmFyLmV5SmZvb2Jhci4xMjM0NWFzZGY6", + existingSession: nil, + expectedSession: nil, + }), + Entry("Basic Base64(:x-oauth-basic) (Sentinel password)", jwtSessionLoaderTableInput{ + authorizationHeader: fmt.Sprintf("Basic %s", verifiedTokenXOAuthBasicBase64), + existingSession: nil, + expectedSession: verifiedSession, + }), + ) + + }) + + Context("getJWTSession", func() { + var j *jwtSessionLoader + const nonVerifiedToken = validToken + + BeforeEach(func() { + keyset := noOpKeySet{} + verifier := oidc.NewVerifier("https://issuer.example.com", keyset, + &oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true}) + + j = &jwtSessionLoader{ + jwtRegex: regexp.MustCompile(jwtRegexFormat), + sessionLoaders: []middlewareapi.TokenToSessionLoader{ + { + Verifier: verifier, + TokenToSession: createSessionStateFromBearerToken, + }, + }, + } + }) + + type getJWTSessionTableInput struct { + authorizationHeader string + expectedErr error + expectedSession *sessionsapi.SessionState + } + + DescribeTable("with an authorization header", + func(in getJWTSessionTableInput) { + req := httptest.NewRequest("", "/", nil) + req.Header.Set("Authorization", in.authorizationHeader) + + session, err := j.getJwtSession(req) + if in.expectedErr != nil { + Expect(err).To(MatchError(in.expectedErr)) + } else { + Expect(err).ToNot(HaveOccurred()) + } + Expect(session).To(Equal(in.expectedSession)) + }, + Entry("", getJWTSessionTableInput{ + authorizationHeader: "", + expectedErr: nil, + expectedSession: nil, + }), + Entry("abcdef", getJWTSessionTableInput{ + authorizationHeader: "abcdef", + expectedErr: errors.New("invalid authorization header: \"abcdef\""), + expectedSession: nil, + }), + Entry("Bearer abcdef", getJWTSessionTableInput{ + authorizationHeader: "Bearer abcdef", + expectedErr: errors.New("no valid bearer token found in authorization header"), + expectedSession: nil, + }), + Entry("Bearer ", getJWTSessionTableInput{ + authorizationHeader: fmt.Sprintf("Bearer %s", nonVerifiedToken), + expectedErr: errors.New("unable to verify jwt token: \"Bearer eyJfoobar.eyJfoobar.12345asdf\""), + expectedSession: nil, + }), + Entry("Bearer ", getJWTSessionTableInput{ + authorizationHeader: fmt.Sprintf("Bearer %s", verifiedToken), + expectedErr: nil, + expectedSession: verifiedSession, + }), + Entry("Basic Base64(:) (No password)", getJWTSessionTableInput{ + authorizationHeader: "Basic ZXlKZm9vYmFyLmV5SmZvb2Jhci4xMjM0NWFzZGY6", + expectedErr: errors.New("unable to verify jwt token: \"Basic ZXlKZm9vYmFyLmV5SmZvb2Jhci4xMjM0NWFzZGY6\""), + expectedSession: nil, + }), + Entry("Basic Base64(:x-oauth-basic) (Sentinel password)", getJWTSessionTableInput{ + authorizationHeader: fmt.Sprintf("Basic %s", verifiedTokenXOAuthBasicBase64), + expectedErr: nil, + expectedSession: verifiedSession, + }), + ) + }) + + Context("findBearerTokenFromHeader", func() { + var j *jwtSessionLoader + + BeforeEach(func() { + j = &jwtSessionLoader{ + jwtRegex: regexp.MustCompile(jwtRegexFormat), + } + }) + + type findBearerTokenFromHeaderTableInput struct { + header string + expectedErr error + expectedToken string + } + + DescribeTable("with a header", + func(in findBearerTokenFromHeaderTableInput) { + token, err := j.findBearerTokenFromHeader(in.header) + if in.expectedErr != nil { + Expect(err).To(MatchError(in.expectedErr)) + } else { + Expect(err).ToNot(HaveOccurred()) + } + Expect(token).To(Equal(in.expectedToken)) + }, + Entry("Bearer", findBearerTokenFromHeaderTableInput{ + header: "Bearer", + expectedErr: errors.New("invalid authorization header: \"Bearer\""), + expectedToken: "", + }), + Entry("Bearer abc def", findBearerTokenFromHeaderTableInput{ + header: "Bearer abc def", + expectedErr: errors.New("invalid authorization header: \"Bearer abc def\""), + expectedToken: "", + }), + Entry("Bearer abcdef", findBearerTokenFromHeaderTableInput{ + header: "Bearer abcdef", + expectedErr: errors.New("no valid bearer token found in authorization header"), + expectedToken: "", + }), + Entry("Bearer ", findBearerTokenFromHeaderTableInput{ + header: fmt.Sprintf("Bearer %s", validToken), + expectedErr: nil, + expectedToken: validToken, + }), + Entry("Basic invalid-base64", findBearerTokenFromHeaderTableInput{ + header: "Basic invalid-base64", + expectedErr: errors.New("invalid basic auth token: illegal base64 data at input byte 7"), + expectedToken: "", + }), + Entry("Basic Base64(:) (No password)", findBearerTokenFromHeaderTableInput{ + header: "Basic ZXlKZm9vYmFyLmV5SmZvb2Jhci4xMjM0NWFzZGY6", + expectedErr: nil, + expectedToken: validToken, + }), + Entry("Basic Base64(:x-oauth-basic) (Sentinel password)", findBearerTokenFromHeaderTableInput{ + header: "Basic ZXlKZm9vYmFyLmV5SmZvb2Jhci4xMjM0NWFzZGY6eC1vYXV0aC1iYXNpYw==", + expectedErr: nil, + expectedToken: validToken, + }), + Entry("Basic Base64(any-user:) (Matching password)", findBearerTokenFromHeaderTableInput{ + header: "Basic YW55LXVzZXI6ZXlKZm9vYmFyLmV5SmZvb2Jhci4xMjM0NWFzZGY=", + expectedErr: nil, + expectedToken: validToken, + }), + Entry("Basic Base64(any-user:any-password) (No matches)", findBearerTokenFromHeaderTableInput{ + header: "Basic YW55LXVzZXI6YW55LXBhc3N3b3Jk", + expectedErr: errors.New("invalid basic auth token found in authorization header"), + expectedToken: "", + }), + Entry("Basic Base64(any-user any-password) (Invalid format)", findBearerTokenFromHeaderTableInput{ + header: "Basic YW55LXVzZXIgYW55LXBhc3N3b3Jk", + expectedErr: errors.New("invalid format: \"any-user any-password\""), + expectedToken: "", + }), + Entry("Something ", findBearerTokenFromHeaderTableInput{ + header: fmt.Sprintf("Something %s", validToken), + expectedErr: errors.New("no valid bearer token found in authorization header"), + expectedToken: "", + }), + ) + + }) + + Context("getBasicToken", func() { + var j *jwtSessionLoader + + BeforeEach(func() { + j = &jwtSessionLoader{ + jwtRegex: regexp.MustCompile(jwtRegexFormat), + } + }) + + type getBasicTokenTableInput struct { + token string + expectedErr error + expectedToken string + } + + DescribeTable("with a token", + func(in getBasicTokenTableInput) { + token, err := j.getBasicToken(in.token) + if in.expectedErr != nil { + Expect(err).To(MatchError(in.expectedErr)) + } else { + Expect(err).ToNot(HaveOccurred()) + } + Expect(token).To(Equal(in.expectedToken)) + }, + Entry("invalid-base64", getBasicTokenTableInput{ + token: "invalid-base64", + expectedErr: errors.New("invalid basic auth token: illegal base64 data at input byte 7"), + expectedToken: "", + }), + Entry("Base64(:) (No password)", getBasicTokenTableInput{ + token: "ZXlKZm9vYmFyLmV5SmZvb2Jhci4xMjM0NWFzZGY6", + expectedErr: nil, + expectedToken: validToken, + }), + Entry("Base64(:x-oauth-basic) (Sentinel password)", getBasicTokenTableInput{ + token: "ZXlKZm9vYmFyLmV5SmZvb2Jhci4xMjM0NWFzZGY6eC1vYXV0aC1iYXNpYw==", + expectedErr: nil, + expectedToken: validToken, + }), + Entry("Base64(any-user:) (Matching password)", getBasicTokenTableInput{ + token: "YW55LXVzZXI6ZXlKZm9vYmFyLmV5SmZvb2Jhci4xMjM0NWFzZGY=", + expectedErr: nil, + expectedToken: validToken, + }), + Entry("Base64(any-user:any-password) (No matches)", getBasicTokenTableInput{ + token: "YW55LXVzZXI6YW55LXBhc3N3b3Jk", + expectedErr: errors.New("invalid basic auth token found in authorization header"), + expectedToken: "", + }), + Entry("Base64(any-user any-password) (Invalid format)", getBasicTokenTableInput{ + token: "YW55LXVzZXIgYW55LXBhc3N3b3Jk", + expectedErr: errors.New("invalid format: \"any-user any-password\""), + expectedToken: "", + }), + ) + }) + + Context("splitAuthHeader", func() { + type splitAuthTableInput struct { + header string + expectedErr error + expectedTokenType string + expectedTokenValue string + } + + DescribeTable("with a header value", + func(in splitAuthTableInput) { + tt, tv, err := splitAuthHeader(in.header) + if in.expectedErr != nil { + Expect(err).To(MatchError(in.expectedErr)) + } else { + Expect(err).ToNot(HaveOccurred()) + } + Expect(tt).To(Equal(in.expectedTokenType)) + Expect(tv).To(Equal(in.expectedTokenValue)) + }, + Entry("Bearer abcdef", splitAuthTableInput{ + header: "Bearer abcdef", + expectedErr: nil, + expectedTokenType: "Bearer", + expectedTokenValue: "abcdef", + }), + Entry("Bearer", splitAuthTableInput{ + header: "Bearer", + expectedErr: errors.New("invalid authorization header: \"Bearer\""), + expectedTokenType: "", + expectedTokenValue: "", + }), + Entry("Bearer abc def", splitAuthTableInput{ + header: "Bearer abc def", + expectedErr: errors.New("invalid authorization header: \"Bearer abc def\""), + expectedTokenType: "", + expectedTokenValue: "", + }), + ) + }) + + Context("createSessionStateFromBearerToken", func() { + ctx := context.Background() + expiresFuture := time.Now().Add(time.Duration(5) * time.Minute) + verified := true + notVerified := false + + type idTokenClaims struct { + Email string `json:"email,omitempty"` + Verified *bool `json:"email_verified,omitempty"` + jwt.StandardClaims + } + + type createSessionStateTableInput struct { + idToken idTokenClaims + expectedErr error + expectedUser string + expectedEmail string + expectedExpires *time.Time + } + + DescribeTable("when creating a session from an IDToken", + func(in createSessionStateTableInput) { + verifier := oidc.NewVerifier( + "https://issuer.example.com", + noOpKeySet{}, + &oidc.Config{ClientID: "asdf1234"}, + ) + + key, err := rsa.GenerateKey(rand.Reader, 2048) + Expect(err).ToNot(HaveOccurred()) + + rawIDToken, err := jwt.NewWithClaims(jwt.SigningMethodRS256, in.idToken).SignedString(key) + Expect(err).ToNot(HaveOccurred()) + + // Pass to a dummy Verifier to get an oidc.IDToken from the rawIDToken for our actual test below + idToken, err := verifier.Verify(context.Background(), rawIDToken) + Expect(err).ToNot(HaveOccurred()) + + session, err := createSessionStateFromBearerToken(ctx, rawIDToken, idToken) + if in.expectedErr != nil { + Expect(err).To(MatchError(in.expectedErr)) + Expect(session).To(BeNil()) + return + } + + Expect(err).ToNot(HaveOccurred()) + Expect(session.AccessToken).To(Equal(rawIDToken)) + Expect(session.IDToken).To(Equal(rawIDToken)) + Expect(session.User).To(Equal(in.expectedUser)) + Expect(session.Email).To(Equal(in.expectedEmail)) + Expect(session.ExpiresOn.Unix()).To(Equal(in.expectedExpires.Unix())) + Expect(session.RefreshToken).To(BeEmpty()) + Expect(session.PreferredUsername).To(BeEmpty()) + }, + Entry("with no email", createSessionStateTableInput{ + idToken: idTokenClaims{ + StandardClaims: jwt.StandardClaims{ + Audience: "asdf1234", + ExpiresAt: expiresFuture.Unix(), + Id: "id-some-id", + IssuedAt: time.Now().Unix(), + Issuer: "https://issuer.example.com", + NotBefore: 0, + Subject: "123456789", + }, + }, + expectedErr: nil, + expectedUser: "123456789", + expectedEmail: "123456789", + expectedExpires: &expiresFuture, + }), + Entry("with a verified email", createSessionStateTableInput{ + idToken: idTokenClaims{ + StandardClaims: jwt.StandardClaims{ + Audience: "asdf1234", + ExpiresAt: expiresFuture.Unix(), + Id: "id-some-id", + IssuedAt: time.Now().Unix(), + Issuer: "https://issuer.example.com", + NotBefore: 0, + Subject: "123456789", + }, + Email: "foo@example.com", + Verified: &verified, + }, + expectedErr: nil, + expectedUser: "123456789", + expectedEmail: "foo@example.com", + expectedExpires: &expiresFuture, + }), + Entry("with a non-verified email", createSessionStateTableInput{ + idToken: idTokenClaims{ + StandardClaims: jwt.StandardClaims{ + Audience: "asdf1234", + ExpiresAt: expiresFuture.Unix(), + Id: "id-some-id", + IssuedAt: time.Now().Unix(), + Issuer: "https://issuer.example.com", + NotBefore: 0, + Subject: "123456789", + }, + Email: "foo@example.com", + Verified: ¬Verified, + }, + expectedErr: errors.New("email in id_token (foo@example.com) isn't verified"), + }), + ) + }) +})