diff --git a/pkg/apis/options/sessions.go b/pkg/apis/options/sessions.go index c90c0ac2..1a015f3f 100644 --- a/pkg/apis/options/sessions.go +++ b/pkg/apis/options/sessions.go @@ -5,6 +5,7 @@ type SessionOptions struct { Type string `flag:"session-store-type" cfg:"session_store_type"` Cookie CookieStoreOptions `cfg:",squash"` Redis RedisStoreOptions `cfg:",squash"` + File string `cfg:",squash"` } // CookieSessionStoreType is used to indicate the CookieSessionStore should be @@ -15,6 +16,10 @@ var CookieSessionStoreType = "cookie" // used for storing sessions. var RedisSessionStoreType = "redis" +// FileSessionStoreType defines the session store type as "file" +// typically used for file-based session management. +var FileSessionStoreType = "file" + // CookieStoreOptions contains configuration options for the CookieSessionStore. type CookieStoreOptions struct { Minimal bool `flag:"session-cookie-minimal" cfg:"session_cookie_minimal"` diff --git a/pkg/sessions/file/file_store.go b/pkg/sessions/file/file_store.go new file mode 100644 index 00000000..ceb309e0 --- /dev/null +++ b/pkg/sessions/file/file_store.go @@ -0,0 +1,93 @@ +package file + +import ( + "encoding/json" + "errors" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" + "io/ioutil" + "os" + "sync" +) + +const ( + DefaultFilePermissions = 0600 +) + +type SessionStore struct { + filePath string + lock sync.Mutex + sessions map[string]*sessions.SessionState +} + +// NewFileSessionStore creates a new session store with file persistence. +func NewFileSessionStore(filePath string) (sessions.SessionStore, error) { + store := &SessionStore{ + filePath: filePath, + sessions: make(map[string]*sessions.SessionState), + } + if err := store.loadSessionsFromFile(); err != nil { + return nil, err + } + return store, nil +} + +// Save persists a session state for a given key. +func (store *SessionStore) Save(key string, sessionState *sessions.SessionState) error { + return store.withLock(func() error { + store.sessions[key] = sessionState + return store.saveSessionsToFile() + }) +} + +// Load retrieves a session by key. +func (store *SessionStore) Load(key string) (*sessions.SessionState, error) { + var session *sessions.SessionState + err := store.withLock(func() error { + var exists bool + session, exists = store.sessions[key] + if !exists { + return errors.New("session not found") + } + return nil + }) + return session, err +} + +// Clear removes a session by key. +func (store *SessionStore) Clear(key string) error { + return store.withLock(func() error { + delete(store.sessions, key) + return store.saveSessionsToFile() + }) +} + +// loadSessionsFromFile loads all sessions from the JSON file. +func (store *SessionStore) loadSessionsFromFile() error { + if _, err := os.Stat(store.filePath); os.IsNotExist(err) { + store.sessions = make(map[string]*sessions.SessionState) + return nil + } + + data, err := ioutil.ReadFile(store.filePath) + if err != nil { + return err + } + + return json.Unmarshal(data, &store.sessions) +} + +// saveSessionsToFile writes all sessions to the JSON file with proper formatting. +func (store *SessionStore) saveSessionsToFile() error { + data, err := json.MarshalIndent(store.sessions, "", " ") + if err != nil { + return err + } + return ioutil.WriteFile(store.filePath, data, DefaultFilePermissions) +} + +// withLock handles locking and running critical sections. +func (store *SessionStore) withLock(action func() error) error { + store.lock.Lock() + defer store.lock.Unlock() + return action() +} diff --git a/pkg/sessions/file/file_store_test.go b/pkg/sessions/file/file_store_test.go new file mode 100644 index 00000000..2bce24bc --- /dev/null +++ b/pkg/sessions/file/file_store_test.go @@ -0,0 +1,171 @@ +package file + +import ( + "encoding/json" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" + "io/ioutil" + "os" + "sync" + "testing" +) + +func TestNewFileSessionStore(t *testing.T) { + t.Run("creates new store with empty file", func(t *testing.T) { + tmpFile, _ := ioutil.TempFile("", "test_sessions.json") + defer os.Remove(tmpFile.Name()) + defer tmpFile.Close() + + store, err := NewFileSessionStore(tmpFile.Name()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(store.sessions) != 0 { + t.Fatalf("expected sessions to be empty, got %d", len(store.sessions)) + } + }) + + t.Run("returns error for invalid file", func(t *testing.T) { + store, err := NewFileSessionStore("invalid/path") + if store != nil || err == nil { + t.Fatalf("expected error, but got none") + } + }) +} + +func TestSessionStore_Save(t *testing.T) { + t.Run("saves new session", func(t *testing.T) { + tmpFile, _ := ioutil.TempFile("", "test_sessions.json") + defer os.Remove(tmpFile.Name()) + store, _ := NewFileSessionStore(tmpFile.Name()) + state := &sessions.SessionState{Email: "test@example.com"} + + err := store.Save("key", state) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + data, _ := ioutil.ReadFile(tmpFile.Name()) + savedSessions := map[string]*sessions.SessionState{} + json.Unmarshal(data, &savedSessions) + + if savedState, exists := savedSessions["key"]; !exists || savedState.Email != state.Email { + t.Fatalf("session was not saved correctly") + } + }) + + t.Run("handles marshal error", func(t *testing.T) { + store := &SessionStore{ + sessions: map[string]*sessions.SessionState{"key": nil}, + } + err := store.Save("key", nil) + if err == nil { + t.Fatalf("expected error, but got none") + } + }) +} + +func TestSessionStore_Load(t *testing.T) { + t.Run("loads existing session", func(t *testing.T) { + store := &SessionStore{ + sessions: map[string]*sessions.SessionState{"key": {Email: "test@example.com"}}, + lock: sync.Mutex{}, + } + state, err := store.Load("key") + if err != nil || state.Email != "test@example.com" { + t.Fatalf("failed to load existing session") + } + }) + + t.Run("returns error if session not found", func(t *testing.T) { + store := &SessionStore{sessions: map[string]*sessions.SessionState{}, lock: sync.Mutex{}} + _, err := store.Load("key") + if err == nil { + t.Fatalf("expected error, but got none") + } + }) +} + +func TestSessionStore_Clear(t *testing.T) { + t.Run("clears existing session", func(t *testing.T) { + tmpFile, _ := ioutil.TempFile("", "test_sessions.json") + defer os.Remove(tmpFile.Name()) + store := &SessionStore{ + filePath: tmpFile.Name(), + sessions: map[string]*sessions.SessionState{"key": {Email: "test@example.com"}}, + } + + err := store.Clear("key") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, exists := store.sessions["key"]; exists { + t.Fatalf("failed to clear session") + } + }) + + t.Run("handles clearing non-existing session", func(t *testing.T) { + store := &SessionStore{sessions: map[string]*sessions.SessionState{}} + err := store.Clear("key") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) +} + +func TestSessionStore_loadSessionsFromFile(t *testing.T) { + t.Run("loads sessions from an existing file", func(t *testing.T) { + tmpFile, _ := ioutil.TempFile("", "test_sessions.json") + defer os.Remove(tmpFile.Name()) + + sessions := map[string]*sessions.SessionState{"key": {Email: "test@example.com"}} + data, _ := json.Marshal(sessions) + _ = ioutil.WriteFile(tmpFile.Name(), data, 0600) + + store := &SessionStore{filePath: tmpFile.Name()} + err := store.loadSessionsFromFile() + if err != nil || len(store.sessions) != 1 || store.sessions["key"].Email != "test@example.com" { + t.Fatalf("failed to load sessions from file") + } + }) + + t.Run("initializes empty store if file does not exist", func(t *testing.T) { + store := &SessionStore{filePath: "non_existing_file.json"} + err := store.loadSessionsFromFile() + if err != nil || len(store.sessions) != 0 { + t.Fatalf("expected empty sessions, got error or non-empty sessions") + } + }) +} + +func TestSessionStore_saveSessionsToFile(t *testing.T) { + t.Run("saves sessions to file", func(t *testing.T) { + tmpFile, _ := ioutil.TempFile("", "test_sessions.json") + defer os.Remove(tmpFile.Name()) + + store := &SessionStore{ + filePath: tmpFile.Name(), + sessions: map[string]*sessions.SessionState{"key": {Email: "test@example.com"}}, + } + + err := store.saveSessionsToFile() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + data, _ := ioutil.ReadFile(tmpFile.Name()) + savedSessions := map[string]*sessions.SessionState{} + json.Unmarshal(data, &savedSessions) + + if savedState, exists := savedSessions["key"]; !exists || savedState.Email != "test@example.com" { + t.Fatalf("session was not saved correctly") + } + }) + + t.Run("handles file write errors gracefully", func(t *testing.T) { + store := &SessionStore{filePath: "/invalid_path/file.json"} + err := store.saveSessionsToFile() + if err == nil { + t.Fatalf("expected error, but got none") + } + }) +} diff --git a/pkg/sessions/session_store.go b/pkg/sessions/session_store.go index 3d4b8d97..37e17dba 100644 --- a/pkg/sessions/session_store.go +++ b/pkg/sessions/session_store.go @@ -6,6 +6,7 @@ import ( "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions/cookie" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions/file" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions/redis" ) @@ -16,6 +17,8 @@ func NewSessionStore(opts *options.SessionOptions, cookieOpts *options.Cookie) ( return cookie.NewCookieSessionStore(opts, cookieOpts) case options.RedisSessionStoreType: return redis.NewRedisSessionStore(opts, cookieOpts) + case options.FileSessionStoreType: + return file.NewFileSessionStore(opts.File) default: return nil, fmt.Errorf("unknown session store type '%s'", opts.Type) }