This commit is contained in:
Francesco Pasqualini 2026-02-28 14:06:04 +01:00 committed by GitHub
commit c5d62d6626
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 98 additions and 2 deletions

View File

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"net/http"
"strings"
"time"
"github.com/justinas/alice"
@ -31,6 +32,34 @@ const (
sessionRefreshRetryPeriod = 10 * time.Millisecond
)
// isFatalRefreshError checks if a refresh error indicates a revoked or
// non-existent session that should be immediately invalidated.
// Fatal errors indicate the session is no longer valid at the provider level.
// Non-fatal errors (network issues, timeouts) should not invalidate the session.
//
// Only checks standard OAuth2 error codes (RFC 6749 Section 5.2).
// Does NOT check error_description strings as they are optional and provider-specific.
func isFatalRefreshError(err error) bool {
if err == nil {
return false
}
errStr := err.Error()
// Only check standard OAuth2 error codes (RFC 6749 Section 5.2)
// Do NOT check error_description strings as they are optional and provider-specific
fatalErrors := []string{
"invalid_grant", // refresh token revoked, expired, or session terminated
"invalid_client", // client credentials no longer valid
}
for _, fe := range fatalErrors {
if strings.Contains(errStr, fe) {
return true
}
}
return false
}
// StoredSessionLoaderOptions contains all of the requirements to construct
// a stored session loader.
// All options must be provided.
@ -188,9 +217,25 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req
// We are holding the lock and the session needs a refresh
logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age())
if err := s.refreshSession(rw, req, session); err != nil {
// If a preemptive refresh fails, we still keep the session
// if validateSession succeeds.
logger.Errorf("Unable to refresh session: %v", err)
// Check if this is a fatal error that indicates the session is revoked
// or no longer valid at the provider level
if isFatalRefreshError(err) {
logger.Printf("Fatal refresh error detected (session revoked or invalid), clearing session for user: %s", session.User)
// Clear the session from storage (Redis) and remove the cookie
clearErr := s.store.Clear(rw, req)
if clearErr != nil {
logger.Errorf("Error clearing session: %v", clearErr)
}
// Return error immediately to force re-authentication
return fmt.Errorf("session invalidated due to fatal refresh error: %w", err)
}
// For non-fatal errors (network issues, timeouts), keep the session
// and let validateSession determine if it's still usable
}
// Validate all sessions after any Redeem/Refresh operation (fail or success)

View File

@ -7,6 +7,7 @@ import (
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
@ -801,3 +802,53 @@ func (f *fakeSessionStore) Clear(rw http.ResponseWriter, req *http.Request) erro
func (f *fakeSessionStore) VerifyConnection(_ context.Context) error {
return nil
}
// TestIsFatalRefreshError tests the isFatalRefreshError function to ensure
// it correctly identifies fatal OAuth2 errors that should invalidate a session.
func TestIsFatalRefreshError(t *testing.T) {
tests := []struct {
name string
err error
expected bool
}{
{
name: "nil error",
err: nil,
expected: false,
},
{
name: "invalid_grant error",
err: fmt.Errorf("failed to get token: oauth2: \"invalid_grant\" \"Session not active\""),
expected: true,
},
{
name: "invalid_client error",
err: fmt.Errorf("invalid_client: client not found"),
expected: true,
},
{
name: "network timeout - not fatal",
err: fmt.Errorf("Post \"https://keycloak/token\": dial tcp: connect: connection refused"),
expected: false,
},
{
name: "server error - not fatal",
err: fmt.Errorf("unexpected status code 500"),
expected: false,
},
{
name: "generic refresh error - not fatal",
err: fmt.Errorf("error refreshing tokens: context deadline exceeded"),
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isFatalRefreshError(tt.err)
if result != tt.expected {
t.Errorf("isFatalRefreshError(%v) = %v, want %v", tt.err, result, tt.expected)
}
})
}
}