diff --git a/providers/oidc.go b/providers/oidc.go index 98cefb4b..cdeee3b2 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -206,12 +206,15 @@ func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string) func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, refresh bool) (*sessions.SessionState, error) { idToken, err := p.verifyIDToken(ctx, token) if err != nil { - return nil, fmt.Errorf("could not verify id_token: %v", err) - } - - // IDToken is mandatory in Redeem but optional in Refresh - if idToken == nil && !refresh { - return nil, errors.New("token response did not contain an id_token") + switch err { + case ErrMissingIDToken: + // IDToken is mandatory in Redeem but optional in Refresh + if !refresh { + return nil, errors.New("token response did not contain an id_token") + } + default: + return nil, fmt.Errorf("could not verify id_token: %v", err) + } } ss, err := p.buildSessionFromClaims(idToken) diff --git a/providers/provider_data.go b/providers/provider_data.go index 098e6192..d8c9312b 100644 --- a/providers/provider_data.go +++ b/providers/provider_data.go @@ -129,9 +129,12 @@ type OIDCClaims struct { func (p *ProviderData) verifyIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) { rawIDToken := getIDToken(token) if strings.TrimSpace(rawIDToken) != "" { + if p.Verifier == nil { + return nil, ErrMissingOIDCVerifier + } return p.Verifier.Verify(ctx, rawIDToken) } - return nil, nil + return nil, ErrMissingIDToken } // buildSessionFromClaims uses IDToken claims to populate a fresh SessionState diff --git a/providers/provider_data_test.go b/providers/provider_data_test.go index f94c0db1..80f6ecab 100644 --- a/providers/provider_data_test.go +++ b/providers/provider_data_test.go @@ -137,23 +137,33 @@ func TestProviderData_verifyIDToken(t *testing.T) { testCases := map[string]struct { IDToken *idTokenClaims + Verifier bool ExpectIDToken bool ExpectedError error }{ "Valid ID Token": { IDToken: &defaultIDToken, + Verifier: true, ExpectIDToken: true, ExpectedError: nil, }, "Invalid ID Token": { IDToken: &failureIDToken, + Verifier: true, ExpectIDToken: false, ExpectedError: errors.New("failed to verify signature: the validation failed for subject [123456789]"), }, "Missing ID Token": { IDToken: nil, + Verifier: true, ExpectIDToken: false, - ExpectedError: nil, + ExpectedError: ErrMissingIDToken, + }, + "OIDC Verifier not Configured": { + IDToken: &defaultIDToken, + Verifier: false, + ExpectIDToken: false, + ExpectedError: ErrMissingOIDCVerifier, }, } @@ -170,12 +180,13 @@ func TestProviderData_verifyIDToken(t *testing.T) { }) } - provider := &ProviderData{ - Verifier: oidc.NewVerifier( + provider := &ProviderData{} + if tc.Verifier { + provider.Verifier = oidc.NewVerifier( oidcIssuer, mockJWKS{}, &oidc.Config{ClientID: oidcClientID}, - ), + ) } verified, err := provider.verifyIDToken(context.Background(), token) if err != nil { diff --git a/providers/provider_default.go b/providers/provider_default.go index 012a538c..d3c6d113 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -22,6 +22,14 @@ var ( // code ErrMissingCode = errors.New("missing code") + // ErrMissingIDToken is returned when an oidc.Token does not contain the + // extra `id_token` field for an IDToken. + ErrMissingIDToken = errors.New("missing id_token") + + // ErrMissingOIDCVerifier is returned when a provider didn't set `Verifier` + // but an attempt to call `Verifier.Verify` was about to be made. + ErrMissingOIDCVerifier = errors.New("oidc verifier is not configured") + _ Provider = (*ProviderData)(nil) )