This commit is contained in:
Adel Salakh 2025-05-15 01:03:08 +02:00
parent eecea5a1be
commit c014b45a87
5 changed files with 49 additions and 54 deletions

17
main.go
View File

@ -1,9 +1,12 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"os/signal"
"runtime" "runtime"
"syscall"
"github.com/ghodss/yaml" "github.com/ghodss/yaml"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
@ -55,12 +58,22 @@ func main() {
validator := NewValidator(opts.EmailDomains, opts.AuthenticatedEmailsFile) validator := NewValidator(opts.EmailDomains, opts.AuthenticatedEmailsFile)
oauthproxy, err := NewOAuthProxy(opts, validator) ctx, cancelFunc := context.WithCancel(context.Background())
// Observe signals in background goroutine.
go func() {
sigint := make(chan os.Signal, 1)
signal.Notify(sigint, os.Interrupt, syscall.SIGTERM)
<-sigint
cancelFunc() // cancel the context
}()
oauthproxy, err := NewOAuthProxy(ctx, opts, validator)
if err != nil { if err != nil {
logger.Fatalf("ERROR: Failed to initialise OAuth2 Proxy: %v", err) logger.Fatalf("ERROR: Failed to initialise OAuth2 Proxy: %v", err)
} }
if err := oauthproxy.Start(); err != nil { if err := oauthproxy.Start(ctx); err != nil {
logger.Fatalf("ERROR: Failed to start OAuth2 Proxy: %v", err) logger.Fatalf("ERROR: Failed to start OAuth2 Proxy: %v", err)
} }
} }

View File

@ -10,11 +10,8 @@ import (
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"os"
"os/signal"
"regexp" "regexp"
"strings" "strings"
"syscall"
"time" "time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
@ -113,14 +110,11 @@ type OAuthProxy struct {
redirectValidator redirect.Validator redirectValidator redirect.Validator
appDirector redirect.AppDirector appDirector redirect.AppDirector
cancelFunc context.CancelFunc
cancelCtx context.Context
encodeState bool encodeState bool
} }
// NewOAuthProxy creates a new instance of OAuthProxy from the options provided // NewOAuthProxy creates a new instance of OAuthProxy from the options provided
func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthProxy, error) { func NewOAuthProxy(ctx context.Context, opts *options.Options, validator func(string) bool) (*OAuthProxy, error) {
sessionStore, err := sessions.NewSessionStore(&opts.Session, &opts.Cookie) sessionStore, err := sessions.NewSessionStore(&opts.Session, &opts.Cookie)
if err != nil { if err != nil {
return nil, fmt.Errorf("error initialising session store: %v", err) return nil, fmt.Errorf("error initialising session store: %v", err)
@ -203,9 +197,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
return nil, err return nil, err
} }
cancelCtx, cancelFunc := context.WithCancel(context.Background()) preAuthChain, err := buildPreAuthChain(ctx, opts, sessionStore)
preAuthChain, err := buildPreAuthChain(opts, sessionStore, cancelCtx)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not build pre-auth chain: %v", err) return nil, fmt.Errorf("could not build pre-auth chain: %v", err)
} }
@ -253,8 +245,6 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
redirectValidator: redirectValidator, redirectValidator: redirectValidator,
appDirector: appDirector, appDirector: appDirector,
encodeState: opts.EncodeState, encodeState: opts.EncodeState,
cancelFunc: cancelFunc,
cancelCtx: cancelCtx,
} }
p.buildServeMux(opts.ProxyPrefix) p.buildServeMux(opts.ProxyPrefix)
@ -265,22 +255,14 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
return p, nil return p, nil
} }
func (p *OAuthProxy) Start() error { func (p *OAuthProxy) Start(ctx context.Context) error {
if p.server == nil { if p.server == nil {
// We have to call setupServer before Start is called. // We have to call setupServer before Start is called.
// If this doesn't happen it's a programming error. // If this doesn't happen it's a programming error.
panic("server has not been initialised") panic("server has not been initialised")
} }
// Observe signals in background goroutine. return p.server.Start(ctx)
go func() {
sigint := make(chan os.Signal, 1)
signal.Notify(sigint, os.Interrupt, syscall.SIGTERM)
<-sigint
p.cancelFunc() // cancel the context
}()
return p.server.Start(p.cancelCtx)
} }
func (p *OAuthProxy) setupServer(opts *options.Options) error { func (p *OAuthProxy) setupServer(opts *options.Options) error {
@ -359,7 +341,7 @@ func (p *OAuthProxy) buildProxySubrouter(s *mux.Router) {
// buildPreAuthChain constructs a chain that should process every request before // buildPreAuthChain constructs a chain that should process every request before
// the OAuth2 Proxy authentication logic kicks in. // the OAuth2 Proxy authentication logic kicks in.
// For example forcing HTTPS or health checks. // For example forcing HTTPS or health checks.
func buildPreAuthChain(opts *options.Options, sessionStore sessionsapi.SessionStore, ctx context.Context) (alice.Chain, error) { func buildPreAuthChain(ctx context.Context, opts *options.Options, sessionStore sessionsapi.SessionStore) (alice.Chain, error) {
chain := alice.New(middleware.NewScope(opts.ReverseProxy, opts.Logging.RequestIDHeader)) chain := alice.New(middleware.NewScope(opts.ReverseProxy, opts.Logging.RequestIDHeader))
if opts.ForceHTTPS { if opts.ForceHTTPS {
@ -380,7 +362,7 @@ func buildPreAuthChain(opts *options.Options, sessionStore sessionsapi.SessionSt
// To silence logging of health checks, register the health check handler before // To silence logging of health checks, register the health check handler before
// the logging handler // the logging handler
readinessCheck := middleware.NewReadynessCheck(opts.ReadyPath, sessionStore, ctx) readinessCheck := middleware.NewReadynessCheck(ctx, opts.ReadyPath, sessionStore)
if opts.Logging.SilencePing { if opts.Logging.SilencePing {
chain = chain.Append( chain = chain.Append(
middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents), middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents),

View File

@ -47,7 +47,7 @@ func TestRobotsTxt(t *testing.T) {
err := validation.Validate(opts) err := validation.Validate(opts)
assert.NoError(t, err) assert.NoError(t, err)
proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) proxy, err := NewOAuthProxy(context.Background(), opts, func(string) bool { return true })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -108,7 +108,7 @@ func Test_redeemCode(t *testing.T) {
err := validation.Validate(opts) err := validation.Validate(opts)
assert.NoError(t, err) assert.NoError(t, err)
proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) proxy, err := NewOAuthProxy(context.Background(), opts, func(string) bool { return true })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -160,7 +160,7 @@ func Test_enrichSession(t *testing.T) {
err := validation.Validate(opts) err := validation.Validate(opts)
assert.NoError(t, err) assert.NoError(t, err)
proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) proxy, err := NewOAuthProxy(context.Background(), opts, func(string) bool { return true })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -229,7 +229,7 @@ func TestBasicAuthPassword(t *testing.T) {
providerURL, _ := url.Parse(providerServer.URL) providerURL, _ := url.Parse(providerServer.URL)
const emailAddress = "john.doe@example.com" const emailAddress = "john.doe@example.com"
proxy, err := NewOAuthProxy(opts, func(email string) bool { proxy, err := NewOAuthProxy(context.Background(), opts, func(email string) bool {
return email == emailAddress return email == emailAddress
}) })
if err != nil { if err != nil {
@ -291,7 +291,7 @@ func TestPassGroupsHeadersWithGroups(t *testing.T) {
CreatedAt: &created, CreatedAt: &created,
} }
proxy, err := NewOAuthProxy(opts, func(email string) bool { proxy, err := NewOAuthProxy(context.Background(), opts, func(email string) bool {
return email == emailAddress return email == emailAddress
}) })
assert.NoError(t, err) assert.NoError(t, err)
@ -387,7 +387,7 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) (*PassAccessTokenTe
testProvider := NewTestProvider(providerURL, emailAddress) testProvider := NewTestProvider(providerURL, emailAddress)
testProvider.ValidToken = opts.ValidToken testProvider.ValidToken = opts.ValidToken
patt.proxy, err = NewOAuthProxy(patt.opts, func(email string) bool { patt.proxy, err = NewOAuthProxy(context.Background(), patt.opts, func(email string) bool {
return email == emailAddress return email == emailAddress
}) })
patt.proxy.provider = testProvider patt.proxy.provider = testProvider
@ -592,7 +592,7 @@ func NewSignInPageTest(skipProvider bool) (*SignInPageTest, error) {
return nil, err return nil, err
} }
sipTest.proxy, err = NewOAuthProxy(sipTest.opts, func(email string) bool { sipTest.proxy, err = NewOAuthProxy(context.Background(), sipTest.opts, func(email string) bool {
return true return true
}) })
if err != nil { if err != nil {
@ -628,7 +628,7 @@ func TestManualSignInStoresUserGroupsInTheSession(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
proxy, err := NewOAuthProxy(opts, func(email string) bool { proxy, err := NewOAuthProxy(context.Background(), opts, func(email string) bool {
return true return true
}) })
if err != nil { if err != nil {
@ -676,7 +676,7 @@ func ManualSignInWithCredentials(t *testing.T, user, pass string) int {
t.Fatal(err) t.Fatal(err)
} }
proxy, err := NewOAuthProxy(opts, func(email string) bool { proxy, err := NewOAuthProxy(context.Background(), opts, func(email string) bool {
return true return true
}) })
if err != nil { if err != nil {
@ -826,7 +826,7 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifi
return nil, err return nil, err
} }
pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool { pcTest.proxy, err = NewOAuthProxy(context.Background(), pcTest.opts, func(email string) bool {
return pcTest.validateUser return pcTest.validateUser
}) })
if err != nil { if err != nil {
@ -1200,7 +1200,7 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
err := validation.Validate(pcTest.opts) err := validation.Validate(pcTest.opts)
assert.NoError(t, err) assert.NoError(t, err)
pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool { pcTest.proxy, err = NewOAuthProxy(context.Background(), pcTest.opts, func(email string) bool {
return pcTest.validateUser return pcTest.validateUser
}) })
if err != nil { if err != nil {
@ -1293,7 +1293,7 @@ func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) {
err := validation.Validate(pcTest.opts) err := validation.Validate(pcTest.opts)
assert.NoError(t, err) assert.NoError(t, err)
pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool { pcTest.proxy, err = NewOAuthProxy(context.Background(), pcTest.opts, func(email string) bool {
return pcTest.validateUser return pcTest.validateUser
}) })
if err != nil { if err != nil {
@ -1373,7 +1373,7 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) {
err := validation.Validate(pcTest.opts) err := validation.Validate(pcTest.opts)
assert.NoError(t, err) assert.NoError(t, err)
pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool { pcTest.proxy, err = NewOAuthProxy(context.Background(), pcTest.opts, func(email string) bool {
return pcTest.validateUser return pcTest.validateUser
}) })
if err != nil { if err != nil {
@ -1429,7 +1429,7 @@ func TestAuthSkippedForPreflightRequests(t *testing.T) {
upstreamURL, _ := url.Parse(upstreamServer.URL) upstreamURL, _ := url.Parse(upstreamServer.URL)
proxy, err := NewOAuthProxy(opts, func(string) bool { return false }) proxy, err := NewOAuthProxy(context.Background(), opts, func(string) bool { return false })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -1554,7 +1554,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) er
if err != nil { if err != nil {
return err return err
} }
proxy, err := NewOAuthProxy(st.opts, func(email string) bool { return true }) proxy, err := NewOAuthProxy(context.Background(), st.opts, func(email string) bool { return true })
if err != nil { if err != nil {
return err return err
} }
@ -1642,7 +1642,7 @@ func newAjaxRequestTest(forceJSONErrors bool) (*ajaxRequestTest, error) {
return nil, err return nil, err
} }
test.proxy, err = NewOAuthProxy(test.opts, func(email string) bool { test.proxy, err = NewOAuthProxy(context.Background(), test.opts, func(email string) bool {
return true return true
}) })
if err != nil { if err != nil {
@ -1958,7 +1958,7 @@ func Test_noCacheHeaders(t *testing.T) {
opts.SkipAuthRegex = []string{".*"} opts.SkipAuthRegex = []string{".*"}
err := validation.Validate(opts) err := validation.Validate(opts)
assert.NoError(t, err) assert.NoError(t, err)
proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) proxy, err := NewOAuthProxy(context.Background(), opts, func(_ string) bool { return true })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -2233,7 +2233,7 @@ func TestTrustedIPs(t *testing.T) {
err := validation.Validate(opts) err := validation.Validate(opts)
assert.NoError(t, err) assert.NoError(t, err)
proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) proxy, err := NewOAuthProxy(context.Background(), opts, func(string) bool { return true })
assert.NoError(t, err) assert.NoError(t, err)
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
@ -2478,7 +2478,7 @@ func TestApiRoutes(t *testing.T) {
opts.SkipProviderButton = true opts.SkipProviderButton = true
err := validation.Validate(opts) err := validation.Validate(opts)
assert.NoError(t, err) assert.NoError(t, err)
proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) proxy, err := NewOAuthProxy(context.Background(), opts, func(_ string) bool { return true })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -2561,7 +2561,7 @@ func TestAllowedRequest(t *testing.T) {
} }
err := validation.Validate(opts) err := validation.Validate(opts)
assert.NoError(t, err) assert.NoError(t, err)
proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) proxy, err := NewOAuthProxy(context.Background(), opts, func(_ string) bool { return true })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -2660,7 +2660,7 @@ func TestAllowedRequestWithForwardedUriHeader(t *testing.T) {
} }
err := validation.Validate(opts) err := validation.Validate(opts)
assert.NoError(t, err) assert.NoError(t, err)
proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) proxy, err := NewOAuthProxy(context.Background(), opts, func(_ string) bool { return true })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -2759,7 +2759,7 @@ func TestAllowedRequestNegateWithoutMethod(t *testing.T) {
} }
err := validation.Validate(opts) err := validation.Validate(opts)
assert.NoError(t, err) assert.NoError(t, err)
proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) proxy, err := NewOAuthProxy(context.Background(), opts, func(_ string) bool { return true })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -2859,7 +2859,7 @@ func TestAllowedRequestNegateWithMethod(t *testing.T) {
} }
err := validation.Validate(opts) err := validation.Validate(opts)
assert.NoError(t, err) assert.NoError(t, err)
proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) proxy, err := NewOAuthProxy(context.Background(), opts, func(_ string) bool { return true })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -3460,7 +3460,7 @@ func TestGetOAuthRedirectURI(t *testing.T) {
err := validation.Validate(baseOpts) err := validation.Validate(baseOpts)
assert.NoError(t, err) assert.NoError(t, err)
proxy, err := NewOAuthProxy(tt.setupOpts(baseOpts), func(string) bool { return true }) proxy, err := NewOAuthProxy(context.Background(), tt.setupOpts(baseOpts), func(string) bool { return true })
assert.NoError(t, err) assert.NoError(t, err)
assert.Equalf(t, tt.want, proxy.getOAuthRedirectURI(tt.req), "getOAuthRedirectURI(%v)", tt.req) assert.Equalf(t, tt.want, proxy.getOAuthRedirectURI(tt.req), "getOAuthRedirectURI(%v)", tt.req)

View File

@ -16,13 +16,13 @@ type Verifiable interface {
// NewReadynessCheck returns a middleware that performs deep health checks // NewReadynessCheck returns a middleware that performs deep health checks
// (verifies the connection to any underlying store) on a specific `path` // (verifies the connection to any underlying store) on a specific `path`
func NewReadynessCheck(path string, verifiable Verifiable, ctx context.Context) alice.Constructor { func NewReadynessCheck(ctx context.Context, path string, verifiable Verifiable) alice.Constructor {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return readynessCheck(path, verifiable, next, ctx) return readynessCheck(ctx, path, verifiable, next)
} }
} }
func readynessCheck(path string, verifiable Verifiable, next http.Handler, ctx context.Context) http.Handler { func readynessCheck(ctx context.Context, path string, verifiable Verifiable, next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if ctx.Err() != nil { if ctx.Err() != nil {
rw.WriteHeader(http.StatusServiceUnavailable) rw.WriteHeader(http.StatusServiceUnavailable)

View File

@ -27,7 +27,7 @@ var _ = Describe("ReadynessCheck suite", func() {
ctx := context.Background() ctx := context.Background()
handler := NewReadynessCheck(in.readyPath, in.healthVerifiable, ctx)(http.NotFoundHandler()) handler := NewReadynessCheck(ctx, in.readyPath, in.healthVerifiable)(http.NotFoundHandler())
handler.ServeHTTP(rw, req) handler.ServeHTTP(rw, req)
Expect(rw.Code).To(Equal(in.expectedStatus)) Expect(rw.Code).To(Equal(in.expectedStatus))