diff --git a/pkg/apis/options/sessions.go b/pkg/apis/options/sessions.go index 1a015f3f..c2d4224b 100644 --- a/pkg/apis/options/sessions.go +++ b/pkg/apis/options/sessions.go @@ -5,7 +5,6 @@ type SessionOptions struct { Type string `flag:"session-store-type" cfg:"session_store_type"` Cookie CookieStoreOptions `cfg:",squash"` Redis RedisStoreOptions `cfg:",squash"` - File string `cfg:",squash"` } // CookieSessionStoreType is used to indicate the CookieSessionStore should be @@ -16,9 +15,9 @@ var CookieSessionStoreType = "cookie" // used for storing sessions. var RedisSessionStoreType = "redis" -// FileSessionStoreType defines the session store type as "file" +// InMemorySessionStoreType defines the session store type as "memory" // typically used for file-based session management. -var FileSessionStoreType = "file" +var InMemorySessionStoreType = "memory" // CookieStoreOptions contains configuration options for the CookieSessionStore. type CookieStoreOptions struct { diff --git a/pkg/sessions/file/file_store.go b/pkg/sessions/file/file_store.go deleted file mode 100644 index ceb309e0..00000000 --- a/pkg/sessions/file/file_store.go +++ /dev/null @@ -1,93 +0,0 @@ -package file - -import ( - "encoding/json" - "errors" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" - "io/ioutil" - "os" - "sync" -) - -const ( - DefaultFilePermissions = 0600 -) - -type SessionStore struct { - filePath string - lock sync.Mutex - sessions map[string]*sessions.SessionState -} - -// NewFileSessionStore creates a new session store with file persistence. -func NewFileSessionStore(filePath string) (sessions.SessionStore, error) { - store := &SessionStore{ - filePath: filePath, - sessions: make(map[string]*sessions.SessionState), - } - if err := store.loadSessionsFromFile(); err != nil { - return nil, err - } - return store, nil -} - -// Save persists a session state for a given key. -func (store *SessionStore) Save(key string, sessionState *sessions.SessionState) error { - return store.withLock(func() error { - store.sessions[key] = sessionState - return store.saveSessionsToFile() - }) -} - -// Load retrieves a session by key. -func (store *SessionStore) Load(key string) (*sessions.SessionState, error) { - var session *sessions.SessionState - err := store.withLock(func() error { - var exists bool - session, exists = store.sessions[key] - if !exists { - return errors.New("session not found") - } - return nil - }) - return session, err -} - -// Clear removes a session by key. -func (store *SessionStore) Clear(key string) error { - return store.withLock(func() error { - delete(store.sessions, key) - return store.saveSessionsToFile() - }) -} - -// loadSessionsFromFile loads all sessions from the JSON file. -func (store *SessionStore) loadSessionsFromFile() error { - if _, err := os.Stat(store.filePath); os.IsNotExist(err) { - store.sessions = make(map[string]*sessions.SessionState) - return nil - } - - data, err := ioutil.ReadFile(store.filePath) - if err != nil { - return err - } - - return json.Unmarshal(data, &store.sessions) -} - -// saveSessionsToFile writes all sessions to the JSON file with proper formatting. -func (store *SessionStore) saveSessionsToFile() error { - data, err := json.MarshalIndent(store.sessions, "", " ") - if err != nil { - return err - } - return ioutil.WriteFile(store.filePath, data, DefaultFilePermissions) -} - -// withLock handles locking and running critical sections. -func (store *SessionStore) withLock(action func() error) error { - store.lock.Lock() - defer store.lock.Unlock() - return action() -} diff --git a/pkg/sessions/file/file_store_test.go b/pkg/sessions/file/file_store_test.go deleted file mode 100644 index 2bce24bc..00000000 --- a/pkg/sessions/file/file_store_test.go +++ /dev/null @@ -1,171 +0,0 @@ -package file - -import ( - "encoding/json" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" - "io/ioutil" - "os" - "sync" - "testing" -) - -func TestNewFileSessionStore(t *testing.T) { - t.Run("creates new store with empty file", func(t *testing.T) { - tmpFile, _ := ioutil.TempFile("", "test_sessions.json") - defer os.Remove(tmpFile.Name()) - defer tmpFile.Close() - - store, err := NewFileSessionStore(tmpFile.Name()) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if len(store.sessions) != 0 { - t.Fatalf("expected sessions to be empty, got %d", len(store.sessions)) - } - }) - - t.Run("returns error for invalid file", func(t *testing.T) { - store, err := NewFileSessionStore("invalid/path") - if store != nil || err == nil { - t.Fatalf("expected error, but got none") - } - }) -} - -func TestSessionStore_Save(t *testing.T) { - t.Run("saves new session", func(t *testing.T) { - tmpFile, _ := ioutil.TempFile("", "test_sessions.json") - defer os.Remove(tmpFile.Name()) - store, _ := NewFileSessionStore(tmpFile.Name()) - state := &sessions.SessionState{Email: "test@example.com"} - - err := store.Save("key", state) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - data, _ := ioutil.ReadFile(tmpFile.Name()) - savedSessions := map[string]*sessions.SessionState{} - json.Unmarshal(data, &savedSessions) - - if savedState, exists := savedSessions["key"]; !exists || savedState.Email != state.Email { - t.Fatalf("session was not saved correctly") - } - }) - - t.Run("handles marshal error", func(t *testing.T) { - store := &SessionStore{ - sessions: map[string]*sessions.SessionState{"key": nil}, - } - err := store.Save("key", nil) - if err == nil { - t.Fatalf("expected error, but got none") - } - }) -} - -func TestSessionStore_Load(t *testing.T) { - t.Run("loads existing session", func(t *testing.T) { - store := &SessionStore{ - sessions: map[string]*sessions.SessionState{"key": {Email: "test@example.com"}}, - lock: sync.Mutex{}, - } - state, err := store.Load("key") - if err != nil || state.Email != "test@example.com" { - t.Fatalf("failed to load existing session") - } - }) - - t.Run("returns error if session not found", func(t *testing.T) { - store := &SessionStore{sessions: map[string]*sessions.SessionState{}, lock: sync.Mutex{}} - _, err := store.Load("key") - if err == nil { - t.Fatalf("expected error, but got none") - } - }) -} - -func TestSessionStore_Clear(t *testing.T) { - t.Run("clears existing session", func(t *testing.T) { - tmpFile, _ := ioutil.TempFile("", "test_sessions.json") - defer os.Remove(tmpFile.Name()) - store := &SessionStore{ - filePath: tmpFile.Name(), - sessions: map[string]*sessions.SessionState{"key": {Email: "test@example.com"}}, - } - - err := store.Clear("key") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if _, exists := store.sessions["key"]; exists { - t.Fatalf("failed to clear session") - } - }) - - t.Run("handles clearing non-existing session", func(t *testing.T) { - store := &SessionStore{sessions: map[string]*sessions.SessionState{}} - err := store.Clear("key") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - }) -} - -func TestSessionStore_loadSessionsFromFile(t *testing.T) { - t.Run("loads sessions from an existing file", func(t *testing.T) { - tmpFile, _ := ioutil.TempFile("", "test_sessions.json") - defer os.Remove(tmpFile.Name()) - - sessions := map[string]*sessions.SessionState{"key": {Email: "test@example.com"}} - data, _ := json.Marshal(sessions) - _ = ioutil.WriteFile(tmpFile.Name(), data, 0600) - - store := &SessionStore{filePath: tmpFile.Name()} - err := store.loadSessionsFromFile() - if err != nil || len(store.sessions) != 1 || store.sessions["key"].Email != "test@example.com" { - t.Fatalf("failed to load sessions from file") - } - }) - - t.Run("initializes empty store if file does not exist", func(t *testing.T) { - store := &SessionStore{filePath: "non_existing_file.json"} - err := store.loadSessionsFromFile() - if err != nil || len(store.sessions) != 0 { - t.Fatalf("expected empty sessions, got error or non-empty sessions") - } - }) -} - -func TestSessionStore_saveSessionsToFile(t *testing.T) { - t.Run("saves sessions to file", func(t *testing.T) { - tmpFile, _ := ioutil.TempFile("", "test_sessions.json") - defer os.Remove(tmpFile.Name()) - - store := &SessionStore{ - filePath: tmpFile.Name(), - sessions: map[string]*sessions.SessionState{"key": {Email: "test@example.com"}}, - } - - err := store.saveSessionsToFile() - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - data, _ := ioutil.ReadFile(tmpFile.Name()) - savedSessions := map[string]*sessions.SessionState{} - json.Unmarshal(data, &savedSessions) - - if savedState, exists := savedSessions["key"]; !exists || savedState.Email != "test@example.com" { - t.Fatalf("session was not saved correctly") - } - }) - - t.Run("handles file write errors gracefully", func(t *testing.T) { - store := &SessionStore{filePath: "/invalid_path/file.json"} - err := store.saveSessionsToFile() - if err == nil { - t.Fatalf("expected error, but got none") - } - }) -} diff --git a/pkg/sessions/memory/memory_store.go b/pkg/sessions/memory/memory_store.go new file mode 100644 index 00000000..9664793d --- /dev/null +++ b/pkg/sessions/memory/memory_store.go @@ -0,0 +1,115 @@ +package memory + +import ( + "context" + "errors" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions/persistence" + "sync" + "time" + + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" +) + +// InMemoryStore is an in-memory implementation of the Store interface. +type InMemoryStore struct { + mu sync.RWMutex + store map[string][]byte + timeouts map[string]time.Time +} + +// NewInMemoryStore creates a new instance of InMemoryStore. +func NewInMemoryStore(opts *options.SessionOptions, cookieOpts *options.Cookie) (sessions.SessionStore, error) { + ims := &InMemoryStore{ + store: make(map[string][]byte), + timeouts: make(map[string]time.Time), + } + + return persistence.NewManager(ims, cookieOpts), nil +} + +// Save stores the session data in memory with a specified expiration time. +func (s *InMemoryStore) Save(ctx context.Context, key string, value []byte, expiration time.Duration) error { + s.mu.Lock() + defer s.mu.Unlock() + + s.store[key] = value + s.timeouts[key] = time.Now().Add(expiration) + return nil +} + +// Load retrieves the session data from memory. +func (s *InMemoryStore) Load(ctx context.Context, key string) ([]byte, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + if timeout, ok := s.timeouts[key]; ok { + if time.Now().After(timeout) { + delete(s.store, key) + delete(s.timeouts, key) + return nil, errors.New("session expired") + } + } + + value, ok := s.store[key] + if !ok { + return nil, errors.New("session not found") + } + return value, nil +} + +// Clear removes the session data from memory. +func (s *InMemoryStore) Clear(ctx context.Context, key string) error { + s.mu.Lock() + defer s.mu.Unlock() + + delete(s.store, key) + delete(s.timeouts, key) + return nil +} + +// Lock returns a lock for the given key. +func (s *InMemoryStore) Lock(key string) sessions.Lock { + return &inMemoryLock{key: key, store: s} +} + +// VerifyConnection is a no-op for in-memory storage. +func (s *InMemoryStore) VerifyConnection(ctx context.Context) error { + return nil +} + +// inMemoryLock is a simple implementation of the sessions.Lock interface. +type inMemoryLock struct { + key string + store *InMemoryStore +} + +// Obtain tries to create a lock or returns an error if one already exists. +func (l *inMemoryLock) Obtain(ctx context.Context, expiration time.Duration) error { + l.store.mu.Lock() + defer l.store.mu.Unlock() + // Logic to add a lock with a timeout + return nil +} + +// Peek checks if the lock exists. +func (l *inMemoryLock) Peek(ctx context.Context) (bool, error) { + l.store.mu.RLock() + defer l.store.mu.RUnlock() + // Logic to check if the lock exists + return true, nil +} + +// Refresh updates the expiration timeout of an existing lock. +func (l *inMemoryLock) Refresh(ctx context.Context, expiration time.Duration) error { + l.store.mu.Lock() + defer l.store.mu.Unlock() + // Logic to update the lock timeout + return nil +} + +// Release removes the existing lock. +func (l *inMemoryLock) Release(ctx context.Context) error { + l.store.mu.Unlock() + return nil +} diff --git a/pkg/sessions/session_store.go b/pkg/sessions/session_store.go index 37e17dba..b86a2360 100644 --- a/pkg/sessions/session_store.go +++ b/pkg/sessions/session_store.go @@ -6,7 +6,7 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions/cookie" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions/file" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions/memory" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions/redis" ) @@ -17,8 +17,8 @@ func NewSessionStore(opts *options.SessionOptions, cookieOpts *options.Cookie) ( return cookie.NewCookieSessionStore(opts, cookieOpts) case options.RedisSessionStoreType: return redis.NewRedisSessionStore(opts, cookieOpts) - case options.FileSessionStoreType: - return file.NewFileSessionStore(opts.File) + case options.InMemorySessionStoreType: + return memory.NewInMemoryStore(opts, cookieOpts) default: return nil, fmt.Errorf("unknown session store type '%s'", opts.Type) }