Merge pull request #1047 from oauth2-proxy/http-server
Refactor HTTP Server and add ServerGroup to handle graceful shutdown of multiple servers
This commit is contained in:
		
						commit
						6894738d97
					
				|  | @ -8,6 +8,7 @@ | ||||||
| 
 | 
 | ||||||
| ## Changes since v7.0.1 | ## Changes since v7.0.1 | ||||||
| 
 | 
 | ||||||
|  | - [#1047](https://github.com/oauth2-proxy/oauth2-proxy/pull/1047) Refactor HTTP Server and add ServerGroup to handle graceful shutdown of multiple servers (@JoelSpeed) | ||||||
| - [#1070](https://github.com/oauth2-proxy/oauth2-proxy/pull/1070) Refactor logging middleware to middleware package (@NickMeves) | - [#1070](https://github.com/oauth2-proxy/oauth2-proxy/pull/1070) Refactor logging middleware to middleware package (@NickMeves) | ||||||
| - [#1064](https://github.com/oauth2-proxy/oauth2-proxy/pull/1064) Add support for setting groups on session when using basic auth (@stefansedich) | - [#1064](https://github.com/oauth2-proxy/oauth2-proxy/pull/1064) Add support for setting groups on session when using basic auth (@stefansedich) | ||||||
| - [#1056](https://github.com/oauth2-proxy/oauth2-proxy/pull/1056) Add option for custom logos on the sign in page (@JoelSpeed) | - [#1056](https://github.com/oauth2-proxy/oauth2-proxy/pull/1056) Add option for custom logos on the sign in page (@JoelSpeed) | ||||||
|  |  | ||||||
|  | @ -117,6 +117,8 @@ They may change between releases without notice. | ||||||
| | `upstreams` | _[Upstreams](#upstreams)_ | Upstreams is used to configure upstream servers.<br/>Once a user is authenticated, requests to the server will be proxied to<br/>these upstream servers based on the path mappings defined in this list. | | | `upstreams` | _[Upstreams](#upstreams)_ | Upstreams is used to configure upstream servers.<br/>Once a user is authenticated, requests to the server will be proxied to<br/>these upstream servers based on the path mappings defined in this list. | | ||||||
| | `injectRequestHeaders` | _[[]Header](#header)_ | InjectRequestHeaders is used to configure headers that should be added<br/>to requests to upstream servers.<br/>Headers may source values from either the authenticated user's session<br/>or from a static secret value. | | | `injectRequestHeaders` | _[[]Header](#header)_ | InjectRequestHeaders is used to configure headers that should be added<br/>to requests to upstream servers.<br/>Headers may source values from either the authenticated user's session<br/>or from a static secret value. | | ||||||
| | `injectResponseHeaders` | _[[]Header](#header)_ | InjectResponseHeaders is used to configure headers that should be added<br/>to responses from the proxy.<br/>This is typically used when using the proxy as an external authentication<br/>provider in conjunction with another proxy such as NGINX and its<br/>auth_request module.<br/>Headers may source values from either the authenticated user's session<br/>or from a static secret value. | | | `injectResponseHeaders` | _[[]Header](#header)_ | InjectResponseHeaders is used to configure headers that should be added<br/>to responses from the proxy.<br/>This is typically used when using the proxy as an external authentication<br/>provider in conjunction with another proxy such as NGINX and its<br/>auth_request module.<br/>Headers may source values from either the authenticated user's session<br/>or from a static secret value. | | ||||||
|  | | `server` | _[Server](#server)_ | Server is used to configure the HTTP(S) server for the proxy application.<br/>You may choose to run both HTTP and HTTPS servers simultaneously.<br/>This can be done by setting the BindAddress and the SecureBindAddress simultaneously.<br/>To use the secure server you must configure a TLS certificate and key. | | ||||||
|  | | `metricsServer` | _[Server](#server)_ | MetricsServer is used to configure the HTTP(S) server for metrics.<br/>You may choose to run both HTTP and HTTPS servers simultaneously.<br/>This can be done by setting the BindAddress and the SecureBindAddress simultaneously.<br/>To use the secure server you must configure a TLS certificate and key. | | ||||||
| 
 | 
 | ||||||
| ### ClaimSource | ### ClaimSource | ||||||
| 
 | 
 | ||||||
|  | @ -172,7 +174,7 @@ make up the header value | ||||||
| 
 | 
 | ||||||
| ### SecretSource | ### SecretSource | ||||||
| 
 | 
 | ||||||
| (**Appears on:** [ClaimSource](#claimsource), [HeaderValue](#headervalue)) | (**Appears on:** [ClaimSource](#claimsource), [HeaderValue](#headervalue), [TLS](#tls)) | ||||||
| 
 | 
 | ||||||
| SecretSource references an individual secret value. | SecretSource references an individual secret value. | ||||||
| Only one source within the struct should be defined at any time. | Only one source within the struct should be defined at any time. | ||||||
|  | @ -183,6 +185,29 @@ Only one source within the struct should be defined at any time. | ||||||
| | `fromEnv` | _string_ | FromEnv expects the name of an environment variable. | | | `fromEnv` | _string_ | FromEnv expects the name of an environment variable. | | ||||||
| | `fromFile` | _string_ | FromFile expects a path to a file containing the secret value. | | | `fromFile` | _string_ | FromFile expects a path to a file containing the secret value. | | ||||||
| 
 | 
 | ||||||
|  | ### Server | ||||||
|  | 
 | ||||||
|  | (**Appears on:** [AlphaOptions](#alphaoptions)) | ||||||
|  | 
 | ||||||
|  | Server represents the configuration for an HTTP(S) server | ||||||
|  | 
 | ||||||
|  | | Field | Type | Description | | ||||||
|  | | ----- | ---- | ----------- | | ||||||
|  | | `BindAddress` | _string_ | BindAddress is the the address on which to serve traffic.<br/>Leave blank or set to "-" to disable. | | ||||||
|  | | `SecureBindAddress` | _string_ | SecureBindAddress is the the address on which to serve secure traffic.<br/>Leave blank or set to "-" to disable. | | ||||||
|  | | `TLS` | _[TLS](#tls)_ | TLS contains the information for loading the certificate and key for the<br/>secure traffic. | | ||||||
|  | 
 | ||||||
|  | ### TLS | ||||||
|  | 
 | ||||||
|  | (**Appears on:** [Server](#server)) | ||||||
|  | 
 | ||||||
|  | TLS contains the information for loading a TLS certifcate and key. | ||||||
|  | 
 | ||||||
|  | | Field | Type | Description | | ||||||
|  | | ----- | ---- | ----------- | | ||||||
|  | | `Key` | _[SecretSource](#secretsource)_ | Key is the the TLS key data to use.<br/>Typically this will come from a file. | | ||||||
|  | | `Cert` | _[SecretSource](#secretsource)_ | Cert is the TLS certificate data to use.<br/>Typically this will come from a file. | | ||||||
|  | 
 | ||||||
| ### Upstream | ### Upstream | ||||||
| 
 | 
 | ||||||
| (**Appears on:** [Upstreams](#upstreams)) | (**Appears on:** [Upstreams](#upstreams)) | ||||||
|  |  | ||||||
							
								
								
									
										1
									
								
								go.mod
								
								
								
								
							
							
						
						
									
										1
									
								
								go.mod
								
								
								
								
							|  | @ -30,6 +30,7 @@ require ( | ||||||
| 	golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 | 	golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 | ||||||
| 	golang.org/x/net v0.0.0-20200707034311-ab3426394381 | 	golang.org/x/net v0.0.0-20200707034311-ab3426394381 | ||||||
| 	golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d | 	golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d | ||||||
|  | 	golang.org/x/sync v0.0.0-20201207232520-09787c993a3a | ||||||
| 	google.golang.org/api v0.20.0 | 	google.golang.org/api v0.20.0 | ||||||
| 	gopkg.in/natefinch/lumberjack.v2 v2.0.0 | 	gopkg.in/natefinch/lumberjack.v2 v2.0.0 | ||||||
| 	gopkg.in/square/go-jose.v2 v2.4.1 | 	gopkg.in/square/go-jose.v2 v2.4.1 | ||||||
|  |  | ||||||
							
								
								
									
										3
									
								
								go.sum
								
								
								
								
							
							
						
						
									
										3
									
								
								go.sum
								
								
								
								
							|  | @ -506,7 +506,10 @@ golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJ | ||||||
| golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||||
| golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||||
| golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||||
|  | golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY= | ||||||
| golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||||
|  | golang.org/x/sync v0.0.0-20201207232520-09787c993a3a h1:DcqTD9SDLc+1P/r1EmRBwnVsrOwW+kk2vWf9n+1sGhs= | ||||||
|  | golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||||
| golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | ||||||
| golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | ||||||
| golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | ||||||
|  |  | ||||||
							
								
								
									
										136
									
								
								http.go
								
								
								
								
							
							
						
						
									
										136
									
								
								http.go
								
								
								
								
							|  | @ -1,136 +0,0 @@ | ||||||
| package main |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"context" |  | ||||||
| 	"crypto/tls" |  | ||||||
| 	"errors" |  | ||||||
| 	"net" |  | ||||||
| 	"net/http" |  | ||||||
| 	"strings" |  | ||||||
| 	"time" |  | ||||||
| 
 |  | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" |  | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| // Server represents an HTTP server
 |  | ||||||
| type Server struct { |  | ||||||
| 	Handler http.Handler |  | ||||||
| 	Opts    *options.Options |  | ||||||
| 	stop    chan struct{} // channel for waiting shutdown
 |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // ListenAndServe will serve traffic on HTTP or HTTPS depending on TLS options
 |  | ||||||
| func (s *Server) ListenAndServe() { |  | ||||||
| 	if s.Opts.TLSKeyFile != "" || s.Opts.TLSCertFile != "" { |  | ||||||
| 		s.ServeHTTPS() |  | ||||||
| 	} else { |  | ||||||
| 		s.ServeHTTP() |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // ServeHTTP constructs a net.Listener and starts handling HTTP requests
 |  | ||||||
| func (s *Server) ServeHTTP() { |  | ||||||
| 	HTTPAddress := s.Opts.HTTPAddress |  | ||||||
| 	var scheme string |  | ||||||
| 
 |  | ||||||
| 	i := strings.Index(HTTPAddress, "://") |  | ||||||
| 	if i > -1 { |  | ||||||
| 		scheme = HTTPAddress[0:i] |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	var networkType string |  | ||||||
| 	switch scheme { |  | ||||||
| 	case "", "http": |  | ||||||
| 		networkType = "tcp" |  | ||||||
| 	default: |  | ||||||
| 		networkType = scheme |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	slice := strings.SplitN(HTTPAddress, "//", 2) |  | ||||||
| 	listenAddr := slice[len(slice)-1] |  | ||||||
| 
 |  | ||||||
| 	listener, err := net.Listen(networkType, listenAddr) |  | ||||||
| 	if err != nil { |  | ||||||
| 		logger.Fatalf("FATAL: listen (%s, %s) failed - %s", networkType, listenAddr, err) |  | ||||||
| 	} |  | ||||||
| 	logger.Printf("HTTP: listening on %s", listenAddr) |  | ||||||
| 	s.serve(listener) |  | ||||||
| 	logger.Printf("HTTP: closing %s", listener.Addr()) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // ServeHTTPS constructs a net.Listener and starts handling HTTPS requests
 |  | ||||||
| func (s *Server) ServeHTTPS() { |  | ||||||
| 	addr := s.Opts.HTTPSAddress |  | ||||||
| 	config := &tls.Config{ |  | ||||||
| 		MinVersion: tls.VersionTLS12, |  | ||||||
| 		MaxVersion: tls.VersionTLS13, |  | ||||||
| 	} |  | ||||||
| 	if config.NextProtos == nil { |  | ||||||
| 		config.NextProtos = []string{"http/1.1"} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	var err error |  | ||||||
| 	config.Certificates = make([]tls.Certificate, 1) |  | ||||||
| 	config.Certificates[0], err = tls.LoadX509KeyPair(s.Opts.TLSCertFile, s.Opts.TLSKeyFile) |  | ||||||
| 	if err != nil { |  | ||||||
| 		logger.Fatalf("FATAL: loading tls config (%s, %s) failed - %s", s.Opts.TLSCertFile, s.Opts.TLSKeyFile, err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	ln, err := net.Listen("tcp", addr) |  | ||||||
| 	if err != nil { |  | ||||||
| 		logger.Fatalf("FATAL: listen (%s) failed - %s", addr, err) |  | ||||||
| 	} |  | ||||||
| 	logger.Printf("HTTPS: listening on %s", ln.Addr()) |  | ||||||
| 
 |  | ||||||
| 	tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config) |  | ||||||
| 	s.serve(tlsListener) |  | ||||||
| 	logger.Printf("HTTPS: closing %s", tlsListener.Addr()) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (s *Server) serve(listener net.Listener) { |  | ||||||
| 	srv := &http.Server{Handler: s.Handler} |  | ||||||
| 
 |  | ||||||
| 	// See https://golang.org/pkg/net/http/#Server.Shutdown
 |  | ||||||
| 	idleConnsClosed := make(chan struct{}) |  | ||||||
| 	go func() { |  | ||||||
| 		<-s.stop // wait notification for stopping server
 |  | ||||||
| 
 |  | ||||||
| 		// We received an interrupt signal, shut down.
 |  | ||||||
| 		if err := srv.Shutdown(context.Background()); err != nil { |  | ||||||
| 			// Error from closing listeners, or context timeout:
 |  | ||||||
| 			logger.Printf("HTTP server Shutdown: %v", err) |  | ||||||
| 		} |  | ||||||
| 		close(idleConnsClosed) |  | ||||||
| 	}() |  | ||||||
| 
 |  | ||||||
| 	err := srv.Serve(listener) |  | ||||||
| 	if err != nil && !errors.Is(err, http.ErrServerClosed) { |  | ||||||
| 		logger.Errorf("ERROR: http.Serve() - %s", err) |  | ||||||
| 	} |  | ||||||
| 	<-idleConnsClosed |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
 |  | ||||||
| // connections. It's used by ListenAndServe and ListenAndServeTLS so
 |  | ||||||
| // dead TCP connections (e.g. closing laptop mid-download) eventually
 |  | ||||||
| // go away.
 |  | ||||||
| type tcpKeepAliveListener struct { |  | ||||||
| 	*net.TCPListener |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (ln tcpKeepAliveListener) Accept() (net.Conn, error) { |  | ||||||
| 	tc, err := ln.AcceptTCP() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 	err = tc.SetKeepAlive(true) |  | ||||||
| 	if err != nil { |  | ||||||
| 		logger.Printf("Error setting Keep-Alive: %v", err) |  | ||||||
| 	} |  | ||||||
| 	err = tc.SetKeepAlivePeriod(3 * time.Minute) |  | ||||||
| 	if err != nil { |  | ||||||
| 		logger.Printf("Error setting Keep-Alive period: %v", err) |  | ||||||
| 	} |  | ||||||
| 	return tc, nil |  | ||||||
| } |  | ||||||
							
								
								
									
										39
									
								
								http_test.go
								
								
								
								
							
							
						
						
									
										39
									
								
								http_test.go
								
								
								
								
							|  | @ -1,39 +0,0 @@ | ||||||
| package main |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"net/http" |  | ||||||
| 	"sync" |  | ||||||
| 	"testing" |  | ||||||
| 	"time" |  | ||||||
| 
 |  | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" |  | ||||||
| 	"github.com/stretchr/testify/assert" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| func TestGracefulShutdown(t *testing.T) { |  | ||||||
| 	opts := options.NewOptions() |  | ||||||
| 	stop := make(chan struct{}, 1) |  | ||||||
| 	srv := Server{Handler: http.DefaultServeMux, Opts: opts, stop: stop} |  | ||||||
| 	var wg sync.WaitGroup |  | ||||||
| 	wg.Add(1) |  | ||||||
| 	go func() { |  | ||||||
| 		defer wg.Done() |  | ||||||
| 		srv.ServeHTTP() |  | ||||||
| 	}() |  | ||||||
| 
 |  | ||||||
| 	stop <- struct{}{} // emulate catching signals
 |  | ||||||
| 
 |  | ||||||
| 	// An idiomatic for sync.WaitGroup with timeout
 |  | ||||||
| 	c := make(chan struct{}) |  | ||||||
| 	go func() { |  | ||||||
| 		defer close(c) |  | ||||||
| 		wg.Wait() |  | ||||||
| 	}() |  | ||||||
| 	select { |  | ||||||
| 	case <-c: |  | ||||||
| 	case <-time.After(1 * time.Second): |  | ||||||
| 		t.Fatal("Server should return gracefully but timeout has occurred") |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	assert.Len(t, stop, 0) // check if stop chan is empty
 |  | ||||||
| } |  | ||||||
							
								
								
									
										54
									
								
								main.go
								
								
								
								
							
							
						
						
									
										54
									
								
								main.go
								
								
								
								
							|  | @ -1,20 +1,15 @@ | ||||||
| package main | package main | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" |  | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"math/rand" | 	"math/rand" | ||||||
| 	"net/http" |  | ||||||
| 	"os" | 	"os" | ||||||
| 	"os/signal" |  | ||||||
| 	"runtime" | 	"runtime" | ||||||
| 	"syscall" |  | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/ghodss/yaml" | 	"github.com/ghodss/yaml" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" |  | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/validation" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/validation" | ||||||
| 	"github.com/spf13/pflag" | 	"github.com/spf13/pflag" | ||||||
| ) | ) | ||||||
|  | @ -67,54 +62,9 @@ func main() { | ||||||
| 
 | 
 | ||||||
| 	rand.Seed(time.Now().UnixNano()) | 	rand.Seed(time.Now().UnixNano()) | ||||||
| 
 | 
 | ||||||
| 	oauthProxyStop := make(chan struct{}, 1) | 	if err := oauthproxy.Start(); err != nil { | ||||||
| 	metricsStop := startMetricsServer(opts.MetricsAddress, oauthProxyStop) | 		logger.Fatalf("ERROR: Failed to start OAuth2 Proxy: %v", err) | ||||||
| 
 |  | ||||||
| 	s := &Server{ |  | ||||||
| 		Handler: oauthproxy, |  | ||||||
| 		Opts:    opts, |  | ||||||
| 		stop:    oauthProxyStop, |  | ||||||
| 	} | 	} | ||||||
| 	// Observe signals in background goroutine.
 |  | ||||||
| 	go func() { |  | ||||||
| 		sigint := make(chan os.Signal, 1) |  | ||||||
| 		signal.Notify(sigint, os.Interrupt, syscall.SIGTERM) |  | ||||||
| 		<-sigint |  | ||||||
| 		s.stop <- struct{}{} // notify having caught signal stop oauthproxy
 |  | ||||||
| 		close(metricsStop)   // and the metrics endpoint
 |  | ||||||
| 	}() |  | ||||||
| 	s.ListenAndServe() |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // startMetricsServer will start the metrics server on the specified address.
 |  | ||||||
| // It always return a channel to signal stop even when it does not run.
 |  | ||||||
| func startMetricsServer(address string, oauthProxyStop chan struct{}) chan struct{} { |  | ||||||
| 	stop := make(chan struct{}, 1) |  | ||||||
| 
 |  | ||||||
| 	// Attempt to setup the metrics endpoint if we have an address
 |  | ||||||
| 	if address != "" { |  | ||||||
| 		s := &http.Server{Addr: address, Handler: middleware.DefaultMetricsHandler} |  | ||||||
| 		go func() { |  | ||||||
| 			// ListenAndServe always returns a non-nil error. After Shutdown or
 |  | ||||||
| 			// Close, the returned error is ErrServerClosed
 |  | ||||||
| 			if err := s.ListenAndServe(); err != http.ErrServerClosed { |  | ||||||
| 				logger.Println(err) |  | ||||||
| 				// Stop the metrics shutdown go routine
 |  | ||||||
| 				close(stop) |  | ||||||
| 				// Stop the oauthproxy server, we have encounter an unexpected error
 |  | ||||||
| 				close(oauthProxyStop) |  | ||||||
| 			} |  | ||||||
| 		}() |  | ||||||
| 
 |  | ||||||
| 		go func() { |  | ||||||
| 			<-stop |  | ||||||
| 			if err := s.Shutdown(context.Background()); err != nil { |  | ||||||
| 				logger.Print(err) |  | ||||||
| 			} |  | ||||||
| 		}() |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return stop |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // loadConfiguration will load in the user's configuration.
 | // loadConfiguration will load in the user's configuration.
 | ||||||
|  |  | ||||||
|  | @ -15,6 +15,7 @@ import ( | ||||||
| 
 | 
 | ||||||
| var _ = Describe("Configuration Loading Suite", func() { | var _ = Describe("Configuration Loading Suite", func() { | ||||||
| 	const testLegacyConfig = ` | 	const testLegacyConfig = ` | ||||||
|  | http_address="127.0.0.1:4180" | ||||||
| upstreams="http://httpbin" | upstreams="http://httpbin" | ||||||
| set_basic_auth="true" | set_basic_auth="true" | ||||||
| basic_auth_password="super-secret-password" | basic_auth_password="super-secret-password" | ||||||
|  | @ -54,10 +55,11 @@ injectResponseHeaders: | ||||||
|     prefix: "Basic " |     prefix: "Basic " | ||||||
|     basicAuthPassword: |     basicAuthPassword: | ||||||
|       value: c3VwZXItc2VjcmV0LXBhc3N3b3Jk |       value: c3VwZXItc2VjcmV0LXBhc3N3b3Jk | ||||||
|  | server: | ||||||
|  |   bindAddress: "127.0.0.1:4180" | ||||||
| ` | ` | ||||||
| 
 | 
 | ||||||
| 	const testCoreConfig = ` | 	const testCoreConfig = ` | ||||||
| http_address="0.0.0.0:4180" |  | ||||||
| cookie_secret="OQINaROshtE9TcZkNAm-5Zs2Pv3xaWytBmc5W7sPX7w=" | cookie_secret="OQINaROshtE9TcZkNAm-5Zs2Pv3xaWytBmc5W7sPX7w=" | ||||||
| provider="oidc" | provider="oidc" | ||||||
| email_domains="example.com" | email_domains="example.com" | ||||||
|  | @ -82,7 +84,6 @@ redirect_url="http://localhost:4180/oauth2/callback" | ||||||
| 		opts, err := options.NewLegacyOptions().ToOptions() | 		opts, err := options.NewLegacyOptions().ToOptions() | ||||||
| 		Expect(err).ToNot(HaveOccurred()) | 		Expect(err).ToNot(HaveOccurred()) | ||||||
| 
 | 
 | ||||||
| 		opts.HTTPAddress = "0.0.0.0:4180" |  | ||||||
| 		opts.Cookie.Secret = "OQINaROshtE9TcZkNAm-5Zs2Pv3xaWytBmc5W7sPX7w=" | 		opts.Cookie.Secret = "OQINaROshtE9TcZkNAm-5Zs2Pv3xaWytBmc5W7sPX7w=" | ||||||
| 		opts.ProviderType = "oidc" | 		opts.ProviderType = "oidc" | ||||||
| 		opts.EmailDomains = []string{"example.com"} | 		opts.EmailDomains = []string{"example.com"} | ||||||
|  | @ -203,7 +204,7 @@ redirect_url="http://localhost:4180/oauth2/callback" | ||||||
| 			configContent:      testCoreConfig, | 			configContent:      testCoreConfig, | ||||||
| 			alphaConfigContent: testAlphaConfig + ":", | 			alphaConfigContent: testAlphaConfig + ":", | ||||||
| 			expectedOptions:    func() *options.Options { return nil }, | 			expectedOptions:    func() *options.Options { return nil }, | ||||||
| 			expectedErr:        errors.New("failed to load alpha options: error unmarshalling config: error converting YAML to JSON: yaml: line 34: did not find expected key"), | 			expectedErr:        errors.New("failed to load alpha options: error unmarshalling config: error converting YAML to JSON: yaml: line 36: did not find expected key"), | ||||||
| 		}), | 		}), | ||||||
| 		Entry("with alpha configuration and bad core configuration", loadConfigurationTableInput{ | 		Entry("with alpha configuration and bad core configuration", loadConfigurationTableInput{ | ||||||
| 			configContent:      testCoreConfig + "unknown_field=\"something\"", | 			configContent:      testCoreConfig + "unknown_field=\"something\"", | ||||||
|  |  | ||||||
|  | @ -8,8 +8,11 @@ import ( | ||||||
| 	"net" | 	"net" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/url" | 	"net/url" | ||||||
|  | 	"os" | ||||||
|  | 	"os/signal" | ||||||
| 	"regexp" | 	"regexp" | ||||||
| 	"strings" | 	"strings" | ||||||
|  | 	"syscall" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/justinas/alice" | 	"github.com/justinas/alice" | ||||||
|  | @ -21,6 +24,7 @@ import ( | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption" | ||||||
|  | 	proxyhttp "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/http" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" | ||||||
|  | @ -102,6 +106,7 @@ type OAuthProxy struct { | ||||||
| 	headersChain alice.Chain | 	headersChain alice.Chain | ||||||
| 	preAuthChain alice.Chain | 	preAuthChain alice.Chain | ||||||
| 	pageWriter   pagewriter.Writer | 	pageWriter   pagewriter.Writer | ||||||
|  | 	server       proxyhttp.Server | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // NewOAuthProxy creates a new instance of OAuthProxy from the options provided
 | // NewOAuthProxy creates a new instance of OAuthProxy from the options provided
 | ||||||
|  | @ -184,7 +189,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | ||||||
| 		return nil, fmt.Errorf("could not build headers chain: %v", err) | 		return nil, fmt.Errorf("could not build headers chain: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return &OAuthProxy{ | 	p := &OAuthProxy{ | ||||||
| 		CookieName:     opts.Cookie.Name, | 		CookieName:     opts.Cookie.Name, | ||||||
| 		CSRFCookieName: fmt.Sprintf("%v_%v", opts.Cookie.Name, "csrf"), | 		CSRFCookieName: fmt.Sprintf("%v_%v", opts.Cookie.Name, "csrf"), | ||||||
| 		CookieSeed:     opts.Cookie.Secret, | 		CookieSeed:     opts.Cookie.Secret, | ||||||
|  | @ -223,7 +228,60 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | ||||||
| 		headersChain:       headersChain, | 		headersChain:       headersChain, | ||||||
| 		preAuthChain:       preAuthChain, | 		preAuthChain:       preAuthChain, | ||||||
| 		pageWriter:         pageWriter, | 		pageWriter:         pageWriter, | ||||||
| 	}, nil | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := p.setupServer(opts); err != nil { | ||||||
|  | 		return nil, fmt.Errorf("error setting up server: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return p, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *OAuthProxy) Start() error { | ||||||
|  | 	if p.server == nil { | ||||||
|  | 		// We have to call setupServer before Start is called.
 | ||||||
|  | 		// If this doesn't happen it's a programming error.
 | ||||||
|  | 		panic("server has not been initialised") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	ctx, cancel := context.WithCancel(context.Background()) | ||||||
|  | 
 | ||||||
|  | 	// Observe signals in background goroutine.
 | ||||||
|  | 	go func() { | ||||||
|  | 		sigint := make(chan os.Signal, 1) | ||||||
|  | 		signal.Notify(sigint, os.Interrupt, syscall.SIGTERM) | ||||||
|  | 		<-sigint | ||||||
|  | 		cancel() // cancel the context
 | ||||||
|  | 	}() | ||||||
|  | 
 | ||||||
|  | 	return p.server.Start(ctx) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (p *OAuthProxy) setupServer(opts *options.Options) error { | ||||||
|  | 	serverOpts := proxyhttp.Opts{ | ||||||
|  | 		Handler:           p, | ||||||
|  | 		BindAddress:       opts.Server.BindAddress, | ||||||
|  | 		SecureBindAddress: opts.Server.SecureBindAddress, | ||||||
|  | 		TLS:               opts.Server.TLS, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	appServer, err := proxyhttp.NewServer(serverOpts) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return fmt.Errorf("could not build app server: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	metricsServer, err := proxyhttp.NewServer(proxyhttp.Opts{ | ||||||
|  | 		Handler:           middleware.DefaultMetricsHandler, | ||||||
|  | 		BindAddress:       opts.MetricsServer.BindAddress, | ||||||
|  | 		SecureBindAddress: opts.MetricsServer.BindAddress, | ||||||
|  | 		TLS:               opts.MetricsServer.TLS, | ||||||
|  | 	}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return fmt.Errorf("could not build metrics server: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	p.server = proxyhttp.NewServerGroup(appServer, metricsServer) | ||||||
|  | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // buildPreAuthChain constructs a chain that should process every request before
 | // buildPreAuthChain constructs a chain that should process every request before
 | ||||||
|  | @ -233,9 +291,9 @@ func buildPreAuthChain(opts *options.Options) (alice.Chain, error) { | ||||||
| 	chain := alice.New(middleware.NewScope(opts.ReverseProxy)) | 	chain := alice.New(middleware.NewScope(opts.ReverseProxy)) | ||||||
| 
 | 
 | ||||||
| 	if opts.ForceHTTPS { | 	if opts.ForceHTTPS { | ||||||
| 		_, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress) | 		_, httpsPort, err := net.SplitHostPort(opts.Server.SecureBindAddress) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return alice.Chain{}, fmt.Errorf("invalid HTTPS address %q: %v", opts.HTTPAddress, err) | 			return alice.Chain{}, fmt.Errorf("invalid HTTPS address %q: %v", opts.Server.SecureBindAddress, err) | ||||||
| 		} | 		} | ||||||
| 		chain = chain.Append(middleware.NewRedirectToHTTPS(httpsPort)) | 		chain = chain.Append(middleware.NewRedirectToHTTPS(httpsPort)) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -2341,6 +2341,7 @@ func baseTestOptions() *options.Options { | ||||||
| 			}, | 			}, | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
| 	return opts | 	return opts | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -28,6 +28,18 @@ type AlphaOptions struct { | ||||||
| 	// Headers may source values from either the authenticated user's session
 | 	// Headers may source values from either the authenticated user's session
 | ||||||
| 	// or from a static secret value.
 | 	// or from a static secret value.
 | ||||||
| 	InjectResponseHeaders []Header `json:"injectResponseHeaders,omitempty"` | 	InjectResponseHeaders []Header `json:"injectResponseHeaders,omitempty"` | ||||||
|  | 
 | ||||||
|  | 	// Server is used to configure the HTTP(S) server for the proxy application.
 | ||||||
|  | 	// You may choose to run both HTTP and HTTPS servers simultaneously.
 | ||||||
|  | 	// This can be done by setting the BindAddress and the SecureBindAddress simultaneously.
 | ||||||
|  | 	// To use the secure server you must configure a TLS certificate and key.
 | ||||||
|  | 	Server Server `json:"server,omitempty"` | ||||||
|  | 
 | ||||||
|  | 	// MetricsServer is used to configure the HTTP(S) server for metrics.
 | ||||||
|  | 	// You may choose to run both HTTP and HTTPS servers simultaneously.
 | ||||||
|  | 	// This can be done by setting the BindAddress and the SecureBindAddress simultaneously.
 | ||||||
|  | 	// To use the secure server you must configure a TLS certificate and key.
 | ||||||
|  | 	MetricsServer Server `json:"metricsServer,omitempty"` | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // MergeInto replaces alpha options in the Options struct with the values
 | // MergeInto replaces alpha options in the Options struct with the values
 | ||||||
|  | @ -36,6 +48,8 @@ func (a *AlphaOptions) MergeInto(opts *Options) { | ||||||
| 	opts.UpstreamServers = a.Upstreams | 	opts.UpstreamServers = a.Upstreams | ||||||
| 	opts.InjectRequestHeaders = a.InjectRequestHeaders | 	opts.InjectRequestHeaders = a.InjectRequestHeaders | ||||||
| 	opts.InjectResponseHeaders = a.InjectResponseHeaders | 	opts.InjectResponseHeaders = a.InjectResponseHeaders | ||||||
|  | 	opts.Server = a.Server | ||||||
|  | 	opts.MetricsServer = a.MetricsServer | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ExtractFrom populates the fields in the AlphaOptions with the values from
 | // ExtractFrom populates the fields in the AlphaOptions with the values from
 | ||||||
|  | @ -44,4 +58,6 @@ func (a *AlphaOptions) ExtractFrom(opts *Options) { | ||||||
| 	a.Upstreams = opts.UpstreamServers | 	a.Upstreams = opts.UpstreamServers | ||||||
| 	a.InjectRequestHeaders = opts.InjectRequestHeaders | 	a.InjectRequestHeaders = opts.InjectRequestHeaders | ||||||
| 	a.InjectResponseHeaders = opts.InjectResponseHeaders | 	a.InjectResponseHeaders = opts.InjectResponseHeaders | ||||||
|  | 	a.Server = opts.Server | ||||||
|  | 	a.MetricsServer = opts.MetricsServer | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -18,6 +18,9 @@ type LegacyOptions struct { | ||||||
| 	// Legacy options for injecting request/response headers
 | 	// Legacy options for injecting request/response headers
 | ||||||
| 	LegacyHeaders LegacyHeaders `cfg:",squash"` | 	LegacyHeaders LegacyHeaders `cfg:",squash"` | ||||||
| 
 | 
 | ||||||
|  | 	// Legacy options for the server address and TLS
 | ||||||
|  | 	LegacyServer LegacyServer `cfg:",squash"` | ||||||
|  | 
 | ||||||
| 	Options Options `cfg:",squash"` | 	Options Options `cfg:",squash"` | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -35,6 +38,11 @@ func NewLegacyOptions() *LegacyOptions { | ||||||
| 			SkipAuthStripHeaders: true, | 			SkipAuthStripHeaders: true, | ||||||
| 		}, | 		}, | ||||||
| 
 | 
 | ||||||
|  | 		LegacyServer: LegacyServer{ | ||||||
|  | 			HTTPAddress:  "127.0.0.1:4180", | ||||||
|  | 			HTTPSAddress: ":443", | ||||||
|  | 		}, | ||||||
|  | 
 | ||||||
| 		Options: *NewOptions(), | 		Options: *NewOptions(), | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | @ -44,6 +52,7 @@ func NewLegacyFlagSet() *pflag.FlagSet { | ||||||
| 
 | 
 | ||||||
| 	flagSet.AddFlagSet(legacyUpstreamsFlagSet()) | 	flagSet.AddFlagSet(legacyUpstreamsFlagSet()) | ||||||
| 	flagSet.AddFlagSet(legacyHeadersFlagSet()) | 	flagSet.AddFlagSet(legacyHeadersFlagSet()) | ||||||
|  | 	flagSet.AddFlagSet(legacyServerFlagset()) | ||||||
| 
 | 
 | ||||||
| 	return flagSet | 	return flagSet | ||||||
| } | } | ||||||
|  | @ -56,6 +65,8 @@ func (l *LegacyOptions) ToOptions() (*Options, error) { | ||||||
| 	l.Options.UpstreamServers = upstreams | 	l.Options.UpstreamServers = upstreams | ||||||
| 
 | 
 | ||||||
| 	l.Options.InjectRequestHeaders, l.Options.InjectResponseHeaders = l.LegacyHeaders.convert() | 	l.Options.InjectRequestHeaders, l.Options.InjectResponseHeaders = l.LegacyHeaders.convert() | ||||||
|  | 	l.Options.Server, l.Options.MetricsServer = l.LegacyServer.convert() | ||||||
|  | 
 | ||||||
| 	return &l.Options, nil | 	return &l.Options, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -403,3 +414,69 @@ func getXAuthRequestAccessTokenHeader() Header { | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | type LegacyServer struct { | ||||||
|  | 	MetricsAddress       string `flag:"metrics-address" cfg:"metrics_address"` | ||||||
|  | 	MetricsSecureAddress string `flag:"metrics-secure-address" cfg:"metrics_address"` | ||||||
|  | 	MetricsTLSCertFile   string `flag:"metrics-tls-cert-file" cfg:"tls_cert_file"` | ||||||
|  | 	MetricsTLSKeyFile    string `flag:"metrics-tls-key-file" cfg:"tls_key_file"` | ||||||
|  | 	HTTPAddress          string `flag:"http-address" cfg:"http_address"` | ||||||
|  | 	HTTPSAddress         string `flag:"https-address" cfg:"https_address"` | ||||||
|  | 	TLSCertFile          string `flag:"tls-cert-file" cfg:"tls_cert_file"` | ||||||
|  | 	TLSKeyFile           string `flag:"tls-key-file" cfg:"tls_key_file"` | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func legacyServerFlagset() *pflag.FlagSet { | ||||||
|  | 	flagSet := pflag.NewFlagSet("server", pflag.ExitOnError) | ||||||
|  | 
 | ||||||
|  | 	flagSet.String("metrics-address", "", "the address /metrics will be served on (e.g. \":9100\")") | ||||||
|  | 	flagSet.String("metrics-secure-address", "", "the address /metrics will be served on for HTTPS clients (e.g. \":9100\")") | ||||||
|  | 	flagSet.String("metrics-tls-cert-file", "", "path to certificate file for secure metrics server") | ||||||
|  | 	flagSet.String("metrics-tls-key-file", "", "path to private key file for secure metrics server") | ||||||
|  | 	flagSet.String("http-address", "127.0.0.1:4180", "[http://]<addr>:<port> or unix://<path> to listen on for HTTP clients") | ||||||
|  | 	flagSet.String("https-address", ":443", "<addr>:<port> to listen on for HTTPS clients") | ||||||
|  | 	flagSet.String("tls-cert-file", "", "path to certificate file") | ||||||
|  | 	flagSet.String("tls-key-file", "", "path to private key file") | ||||||
|  | 
 | ||||||
|  | 	return flagSet | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (l LegacyServer) convert() (Server, Server) { | ||||||
|  | 	appServer := Server{ | ||||||
|  | 		BindAddress:       l.HTTPAddress, | ||||||
|  | 		SecureBindAddress: l.HTTPSAddress, | ||||||
|  | 	} | ||||||
|  | 	if l.TLSKeyFile != "" || l.TLSCertFile != "" { | ||||||
|  | 		appServer.TLS = &TLS{ | ||||||
|  | 			Key: &SecretSource{ | ||||||
|  | 				FromFile: l.TLSKeyFile, | ||||||
|  | 			}, | ||||||
|  | 			Cert: &SecretSource{ | ||||||
|  | 				FromFile: l.TLSCertFile, | ||||||
|  | 			}, | ||||||
|  | 		} | ||||||
|  | 		// Preserve backwards compatibility, only run one server
 | ||||||
|  | 		appServer.BindAddress = "" | ||||||
|  | 	} else { | ||||||
|  | 		// Disable the HTTPS server if there's no certificates.
 | ||||||
|  | 		// This preserves backwards compatibility.
 | ||||||
|  | 		appServer.SecureBindAddress = "" | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	metricsServer := Server{ | ||||||
|  | 		BindAddress:       l.MetricsAddress, | ||||||
|  | 		SecureBindAddress: l.MetricsSecureAddress, | ||||||
|  | 	} | ||||||
|  | 	if l.MetricsTLSKeyFile != "" || l.MetricsTLSCertFile != "" { | ||||||
|  | 		metricsServer.TLS = &TLS{ | ||||||
|  | 			Key: &SecretSource{ | ||||||
|  | 				FromFile: l.MetricsTLSKeyFile, | ||||||
|  | 			}, | ||||||
|  | 			Cert: &SecretSource{ | ||||||
|  | 				FromFile: l.MetricsTLSCertFile, | ||||||
|  | 			}, | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return appServer, metricsServer | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -106,6 +106,10 @@ var _ = Describe("Legacy Options", func() { | ||||||
| 
 | 
 | ||||||
| 			opts.InjectResponseHeaders = []Header{} | 			opts.InjectResponseHeaders = []Header{} | ||||||
| 
 | 
 | ||||||
|  | 			opts.Server = Server{ | ||||||
|  | 				BindAddress: "127.0.0.1:4180", | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
| 			converted, err := legacyOpts.ToOptions() | 			converted, err := legacyOpts.ToOptions() | ||||||
| 			Expect(err).ToNot(HaveOccurred()) | 			Expect(err).ToNot(HaveOccurred()) | ||||||
| 			Expect(converted).To(Equal(opts)) | 			Expect(converted).To(Equal(opts)) | ||||||
|  | @ -759,4 +763,93 @@ var _ = Describe("Legacy Options", func() { | ||||||
| 			}), | 			}), | ||||||
| 		) | 		) | ||||||
| 	}) | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("Legacy Servers", func() { | ||||||
|  | 		type legacyServersTableInput struct { | ||||||
|  | 			legacyServer          LegacyServer | ||||||
|  | 			expectedAppServer     Server | ||||||
|  | 			expectedMetricsServer Server | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		const ( | ||||||
|  | 			insecureAddr        = "127.0.0.1:8080" | ||||||
|  | 			insecureMetricsAddr = ":9090" | ||||||
|  | 			secureAddr          = ":443" | ||||||
|  | 			secureMetricsAddr   = ":9443" | ||||||
|  | 			crtPath             = "tls.crt" | ||||||
|  | 			keyPath             = "tls.key" | ||||||
|  | 		) | ||||||
|  | 
 | ||||||
|  | 		var tlsConfig = &TLS{ | ||||||
|  | 			Cert: &SecretSource{ | ||||||
|  | 				FromFile: crtPath, | ||||||
|  | 			}, | ||||||
|  | 			Key: &SecretSource{ | ||||||
|  | 				FromFile: keyPath, | ||||||
|  | 			}, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		DescribeTable("should convert to app and metrics servers", | ||||||
|  | 			func(in legacyServersTableInput) { | ||||||
|  | 				appServer, metricsServer := in.legacyServer.convert() | ||||||
|  | 				Expect(appServer).To(Equal(in.expectedAppServer)) | ||||||
|  | 				Expect(metricsServer).To(Equal(in.expectedMetricsServer)) | ||||||
|  | 			}, | ||||||
|  | 			Entry("with default options only starts app HTTP server", legacyServersTableInput{ | ||||||
|  | 				legacyServer: LegacyServer{ | ||||||
|  | 					HTTPAddress:  insecureAddr, | ||||||
|  | 					HTTPSAddress: secureAddr, | ||||||
|  | 				}, | ||||||
|  | 				expectedAppServer: Server{ | ||||||
|  | 					BindAddress: insecureAddr, | ||||||
|  | 				}, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with TLS options specified only starts app HTTPS server", legacyServersTableInput{ | ||||||
|  | 				legacyServer: LegacyServer{ | ||||||
|  | 					HTTPAddress:  insecureAddr, | ||||||
|  | 					HTTPSAddress: secureAddr, | ||||||
|  | 					TLSKeyFile:   keyPath, | ||||||
|  | 					TLSCertFile:  crtPath, | ||||||
|  | 				}, | ||||||
|  | 				expectedAppServer: Server{ | ||||||
|  | 					SecureBindAddress: secureAddr, | ||||||
|  | 					TLS:               tlsConfig, | ||||||
|  | 				}, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with metrics HTTP and HTTPS addresses", legacyServersTableInput{ | ||||||
|  | 				legacyServer: LegacyServer{ | ||||||
|  | 					HTTPAddress:          insecureAddr, | ||||||
|  | 					HTTPSAddress:         secureAddr, | ||||||
|  | 					MetricsAddress:       insecureMetricsAddr, | ||||||
|  | 					MetricsSecureAddress: secureMetricsAddr, | ||||||
|  | 				}, | ||||||
|  | 				expectedAppServer: Server{ | ||||||
|  | 					BindAddress: insecureAddr, | ||||||
|  | 				}, | ||||||
|  | 				expectedMetricsServer: Server{ | ||||||
|  | 					BindAddress:       insecureMetricsAddr, | ||||||
|  | 					SecureBindAddress: secureMetricsAddr, | ||||||
|  | 				}, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with metrics HTTPS and tls cert/key", legacyServersTableInput{ | ||||||
|  | 				legacyServer: LegacyServer{ | ||||||
|  | 					HTTPAddress:          insecureAddr, | ||||||
|  | 					HTTPSAddress:         secureAddr, | ||||||
|  | 					MetricsAddress:       insecureMetricsAddr, | ||||||
|  | 					MetricsSecureAddress: secureMetricsAddr, | ||||||
|  | 					MetricsTLSKeyFile:    keyPath, | ||||||
|  | 					MetricsTLSCertFile:   crtPath, | ||||||
|  | 				}, | ||||||
|  | 				expectedAppServer: Server{ | ||||||
|  | 					BindAddress: insecureAddr, | ||||||
|  | 				}, | ||||||
|  | 				expectedMetricsServer: Server{ | ||||||
|  | 					BindAddress:       insecureMetricsAddr, | ||||||
|  | 					SecureBindAddress: secureMetricsAddr, | ||||||
|  | 					TLS:               tlsConfig, | ||||||
|  | 				}, | ||||||
|  | 			}), | ||||||
|  | 		) | ||||||
|  | 
 | ||||||
|  | 	}) | ||||||
| }) | }) | ||||||
|  |  | ||||||
|  | @ -22,9 +22,6 @@ type Options struct { | ||||||
| 	ProxyPrefix        string   `flag:"proxy-prefix" cfg:"proxy_prefix"` | 	ProxyPrefix        string   `flag:"proxy-prefix" cfg:"proxy_prefix"` | ||||||
| 	PingPath           string   `flag:"ping-path" cfg:"ping_path"` | 	PingPath           string   `flag:"ping-path" cfg:"ping_path"` | ||||||
| 	PingUserAgent      string   `flag:"ping-user-agent" cfg:"ping_user_agent"` | 	PingUserAgent      string   `flag:"ping-user-agent" cfg:"ping_user_agent"` | ||||||
| 	MetricsAddress     string   `flag:"metrics-address" cfg:"metrics_address"` |  | ||||||
| 	HTTPAddress        string   `flag:"http-address" cfg:"http_address"` |  | ||||||
| 	HTTPSAddress       string   `flag:"https-address" cfg:"https_address"` |  | ||||||
| 	ReverseProxy       bool     `flag:"reverse-proxy" cfg:"reverse_proxy"` | 	ReverseProxy       bool     `flag:"reverse-proxy" cfg:"reverse_proxy"` | ||||||
| 	RealClientIPHeader string   `flag:"real-client-ip-header" cfg:"real_client_ip_header"` | 	RealClientIPHeader string   `flag:"real-client-ip-header" cfg:"real_client_ip_header"` | ||||||
| 	TrustedIPs         []string `flag:"trusted-ip" cfg:"trusted_ips"` | 	TrustedIPs         []string `flag:"trusted-ip" cfg:"trusted_ips"` | ||||||
|  | @ -33,8 +30,6 @@ type Options struct { | ||||||
| 	ClientID           string   `flag:"client-id" cfg:"client_id"` | 	ClientID           string   `flag:"client-id" cfg:"client_id"` | ||||||
| 	ClientSecret       string   `flag:"client-secret" cfg:"client_secret"` | 	ClientSecret       string   `flag:"client-secret" cfg:"client_secret"` | ||||||
| 	ClientSecretFile   string   `flag:"client-secret-file" cfg:"client_secret_file"` | 	ClientSecretFile   string   `flag:"client-secret-file" cfg:"client_secret_file"` | ||||||
| 	TLSCertFile        string   `flag:"tls-cert-file" cfg:"tls_cert_file"` |  | ||||||
| 	TLSKeyFile         string   `flag:"tls-key-file" cfg:"tls_key_file"` |  | ||||||
| 
 | 
 | ||||||
| 	AuthenticatedEmailsFile  string   `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"` | 	AuthenticatedEmailsFile  string   `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"` | ||||||
| 	KeycloakGroups           []string `flag:"keycloak-group" cfg:"keycloak_groups"` | 	KeycloakGroups           []string `flag:"keycloak-group" cfg:"keycloak_groups"` | ||||||
|  | @ -68,6 +63,9 @@ type Options struct { | ||||||
| 	InjectRequestHeaders  []Header `cfg:",internal"` | 	InjectRequestHeaders  []Header `cfg:",internal"` | ||||||
| 	InjectResponseHeaders []Header `cfg:",internal"` | 	InjectResponseHeaders []Header `cfg:",internal"` | ||||||
| 
 | 
 | ||||||
|  | 	Server        Server `cfg:",internal"` | ||||||
|  | 	MetricsServer Server `cfg:",internal"` | ||||||
|  | 
 | ||||||
| 	SkipAuthRegex         []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"` | 	SkipAuthRegex         []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"` | ||||||
| 	SkipAuthRoutes        []string `flag:"skip-auth-route" cfg:"skip_auth_routes"` | 	SkipAuthRoutes        []string `flag:"skip-auth-route" cfg:"skip_auth_routes"` | ||||||
| 	SkipJwtBearerTokens   bool     `flag:"skip-jwt-bearer-tokens" cfg:"skip_jwt_bearer_tokens"` | 	SkipJwtBearerTokens   bool     `flag:"skip-jwt-bearer-tokens" cfg:"skip_jwt_bearer_tokens"` | ||||||
|  | @ -136,10 +134,7 @@ func NewOptions() *Options { | ||||||
| 	return &Options{ | 	return &Options{ | ||||||
| 		ProxyPrefix:                      "/oauth2", | 		ProxyPrefix:                      "/oauth2", | ||||||
| 		ProviderType:                     "google", | 		ProviderType:                     "google", | ||||||
| 		MetricsAddress:                   "", |  | ||||||
| 		PingPath:                         "/ping", | 		PingPath:                         "/ping", | ||||||
| 		HTTPAddress:                      "127.0.0.1:4180", |  | ||||||
| 		HTTPSAddress:                     ":443", |  | ||||||
| 		RealClientIPHeader:               "X-Real-IP", | 		RealClientIPHeader:               "X-Real-IP", | ||||||
| 		ForceHTTPS:                       false, | 		ForceHTTPS:                       false, | ||||||
| 		Cookie:                           cookieDefaults(), | 		Cookie:                           cookieDefaults(), | ||||||
|  | @ -162,14 +157,10 @@ func NewOptions() *Options { | ||||||
| func NewFlagSet() *pflag.FlagSet { | func NewFlagSet() *pflag.FlagSet { | ||||||
| 	flagSet := pflag.NewFlagSet("oauth2-proxy", pflag.ExitOnError) | 	flagSet := pflag.NewFlagSet("oauth2-proxy", pflag.ExitOnError) | ||||||
| 
 | 
 | ||||||
| 	flagSet.String("http-address", "127.0.0.1:4180", "[http://]<addr>:<port> or unix://<path> to listen on for HTTP clients") |  | ||||||
| 	flagSet.String("https-address", ":443", "<addr>:<port> to listen on for HTTPS clients") |  | ||||||
| 	flagSet.Bool("reverse-proxy", false, "are we running behind a reverse proxy, controls whether headers like X-Real-Ip are accepted") | 	flagSet.Bool("reverse-proxy", false, "are we running behind a reverse proxy, controls whether headers like X-Real-Ip are accepted") | ||||||
| 	flagSet.String("real-client-ip-header", "X-Real-IP", "Header used to determine the real IP of the client (one of: X-Forwarded-For, X-Real-IP, or X-ProxyUser-IP)") | 	flagSet.String("real-client-ip-header", "X-Real-IP", "Header used to determine the real IP of the client (one of: X-Forwarded-For, X-Real-IP, or X-ProxyUser-IP)") | ||||||
| 	flagSet.StringSlice("trusted-ip", []string{}, "list of IPs or CIDR ranges to allow to bypass authentication. WARNING: trusting by IP has inherent security flaws, read the configuration documentation for more information.") | 	flagSet.StringSlice("trusted-ip", []string{}, "list of IPs or CIDR ranges to allow to bypass authentication. WARNING: trusting by IP has inherent security flaws, read the configuration documentation for more information.") | ||||||
| 	flagSet.Bool("force-https", false, "force HTTPS redirect for HTTP requests") | 	flagSet.Bool("force-https", false, "force HTTPS redirect for HTTP requests") | ||||||
| 	flagSet.String("tls-cert-file", "", "path to certificate file") |  | ||||||
| 	flagSet.String("tls-key-file", "", "path to private key file") |  | ||||||
| 	flagSet.String("redirect-url", "", "the OAuth Redirect URL. ie: \"https://internalapp.yourcompany.com/oauth2/callback\"") | 	flagSet.String("redirect-url", "", "the OAuth Redirect URL. ie: \"https://internalapp.yourcompany.com/oauth2/callback\"") | ||||||
| 	flagSet.StringSlice("skip-auth-regex", []string{}, "(DEPRECATED for --skip-auth-route) bypass authentication for requests path's that match (may be given multiple times)") | 	flagSet.StringSlice("skip-auth-regex", []string{}, "(DEPRECATED for --skip-auth-route) bypass authentication for requests path's that match (may be given multiple times)") | ||||||
| 	flagSet.StringSlice("skip-auth-route", []string{}, "bypass authentication for requests that match the method & path. Format: method=path_regex OR path_regex alone for all methods") | 	flagSet.StringSlice("skip-auth-route", []string{}, "bypass authentication for requests that match the method & path. Format: method=path_regex OR path_regex alone for all methods") | ||||||
|  | @ -204,7 +195,6 @@ func NewFlagSet() *pflag.FlagSet { | ||||||
| 	flagSet.String("proxy-prefix", "/oauth2", "the url root path that this proxy should be nested under (e.g. /<oauth2>/sign_in)") | 	flagSet.String("proxy-prefix", "/oauth2", "the url root path that this proxy should be nested under (e.g. /<oauth2>/sign_in)") | ||||||
| 	flagSet.String("ping-path", "/ping", "the ping endpoint that can be used for basic health checks") | 	flagSet.String("ping-path", "/ping", "the ping endpoint that can be used for basic health checks") | ||||||
| 	flagSet.String("ping-user-agent", "", "special User-Agent that will be used for basic health checks") | 	flagSet.String("ping-user-agent", "", "special User-Agent that will be used for basic health checks") | ||||||
| 	flagSet.String("metrics-address", "", "the address /metrics will be served on (e.g. \":9100\")") |  | ||||||
| 	flagSet.String("session-store-type", "cookie", "the session storage provider to use") | 	flagSet.String("session-store-type", "cookie", "the session storage provider to use") | ||||||
| 	flagSet.Bool("session-cookie-minimal", false, "strip OAuth tokens from cookie session stores if they aren't needed (cookie session store only)") | 	flagSet.Bool("session-cookie-minimal", false, "strip OAuth tokens from cookie session stores if they aren't needed (cookie session store only)") | ||||||
| 	flagSet.String("redis-connection-url", "", "URL of redis server for redis session storage (eg: redis://HOST[:PORT])") | 	flagSet.String("redis-connection-url", "", "URL of redis server for redis session storage (eg: redis://HOST[:PORT])") | ||||||
|  |  | ||||||
|  | @ -0,0 +1,27 @@ | ||||||
|  | package options | ||||||
|  | 
 | ||||||
|  | // Server represents the configuration for an HTTP(S) server
 | ||||||
|  | type Server struct { | ||||||
|  | 	// BindAddress is the the address on which to serve traffic.
 | ||||||
|  | 	// Leave blank or set to "-" to disable.
 | ||||||
|  | 	BindAddress string | ||||||
|  | 
 | ||||||
|  | 	// SecureBindAddress is the the address on which to serve secure traffic.
 | ||||||
|  | 	// Leave blank or set to "-" to disable.
 | ||||||
|  | 	SecureBindAddress string | ||||||
|  | 
 | ||||||
|  | 	// TLS contains the information for loading the certificate and key for the
 | ||||||
|  | 	// secure traffic.
 | ||||||
|  | 	TLS *TLS | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // TLS contains the information for loading a TLS certifcate and key.
 | ||||||
|  | type TLS struct { | ||||||
|  | 	// Key is the the TLS key data to use.
 | ||||||
|  | 	// Typically this will come from a file.
 | ||||||
|  | 	Key *SecretSource | ||||||
|  | 
 | ||||||
|  | 	// Cert is the TLS certificate data to use.
 | ||||||
|  | 	// Typically this will come from a file.
 | ||||||
|  | 	Cert *SecretSource | ||||||
|  | } | ||||||
|  | @ -0,0 +1,88 @@ | ||||||
|  | package http | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"crypto/rand" | ||||||
|  | 	"crypto/rsa" | ||||||
|  | 	"crypto/tls" | ||||||
|  | 	"crypto/x509" | ||||||
|  | 	"crypto/x509/pkix" | ||||||
|  | 	"encoding/pem" | ||||||
|  | 	"math/big" | ||||||
|  | 	"net" | ||||||
|  | 	"net/http" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||||
|  | 	. "github.com/onsi/ginkgo" | ||||||
|  | 	. "github.com/onsi/gomega" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | var certData []byte | ||||||
|  | var certDataSource, keyDataSource options.SecretSource | ||||||
|  | var client *http.Client | ||||||
|  | 
 | ||||||
|  | func TestHTTPSuite(t *testing.T) { | ||||||
|  | 	logger.SetOutput(GinkgoWriter) | ||||||
|  | 	logger.SetErrOutput(GinkgoWriter) | ||||||
|  | 
 | ||||||
|  | 	RegisterFailHandler(Fail) | ||||||
|  | 	RunSpecs(t, "HTTP") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | var _ = BeforeSuite(func() { | ||||||
|  | 	By("Generating a self-signed cert for TLS tests", func() { | ||||||
|  | 		priv, err := rsa.GenerateKey(rand.Reader, 2048) | ||||||
|  | 		Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 		keyOut := bytes.NewBuffer(nil) | ||||||
|  | 		privBytes, err := x509.MarshalPKCS8PrivateKey(priv) | ||||||
|  | 		Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 		Expect(pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes})).To(Succeed()) | ||||||
|  | 		keyDataSource.Value = keyOut.Bytes() | ||||||
|  | 
 | ||||||
|  | 		serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) | ||||||
|  | 		serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) | ||||||
|  | 		Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 		template := x509.Certificate{ | ||||||
|  | 			SerialNumber: serialNumber, | ||||||
|  | 			Subject: pkix.Name{ | ||||||
|  | 				Organization: []string{"OAuth2 Proxy Test Suite"}, | ||||||
|  | 			}, | ||||||
|  | 			NotBefore:   time.Now(), | ||||||
|  | 			NotAfter:    time.Now().Add(time.Hour), | ||||||
|  | 			IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, | ||||||
|  | 			KeyUsage:    x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, | ||||||
|  | 			ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) | ||||||
|  | 		Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 		certData = certBytes | ||||||
|  | 
 | ||||||
|  | 		certOut := bytes.NewBuffer(nil) | ||||||
|  | 		Expect(pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes})).To(Succeed()) | ||||||
|  | 		certDataSource.Value = certOut.Bytes() | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	By("Setting up a http client", func() { | ||||||
|  | 		cert, err := tls.X509KeyPair(certDataSource.Value, keyDataSource.Value) | ||||||
|  | 		Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 		certificate, err := x509.ParseCertificate(cert.Certificate[0]) | ||||||
|  | 		Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 		certpool := x509.NewCertPool() | ||||||
|  | 		certpool.AddCert(certificate) | ||||||
|  | 
 | ||||||
|  | 		transport := http.DefaultTransport.(*http.Transport).Clone() | ||||||
|  | 		transport.TLSClientConfig.RootCAs = certpool | ||||||
|  | 
 | ||||||
|  | 		client = &http.Client{ | ||||||
|  | 			Transport: transport, | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | }) | ||||||
|  | @ -0,0 +1,245 @@ | ||||||
|  | package http | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"crypto/tls" | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"net" | ||||||
|  | 	"net/http" | ||||||
|  | 	"strings" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options/util" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||||
|  | 	"golang.org/x/sync/errgroup" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // Server represents an HTTP or HTTPS server.
 | ||||||
|  | type Server interface { | ||||||
|  | 	// Start blocks and runs the server.
 | ||||||
|  | 	Start(ctx context.Context) error | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Opts contains the information required to set up the server.
 | ||||||
|  | type Opts struct { | ||||||
|  | 	// Handler is the http.Handler to be used to serve http pages by the server.
 | ||||||
|  | 	Handler http.Handler | ||||||
|  | 
 | ||||||
|  | 	// BindAddress is the address the HTTP server should listen on.
 | ||||||
|  | 	BindAddress string | ||||||
|  | 
 | ||||||
|  | 	// SecureBindAddress is the address the HTTPS server should listen on.
 | ||||||
|  | 	SecureBindAddress string | ||||||
|  | 
 | ||||||
|  | 	// TLS is the TLS configuration for the server.
 | ||||||
|  | 	TLS *options.TLS | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // NewServer creates a new Server from the options given.
 | ||||||
|  | func NewServer(opts Opts) (Server, error) { | ||||||
|  | 	s := &server{ | ||||||
|  | 		handler: opts.Handler, | ||||||
|  | 	} | ||||||
|  | 	if err := s.setupListener(opts); err != nil { | ||||||
|  | 		return nil, fmt.Errorf("error setting up listener: %v", err) | ||||||
|  | 	} | ||||||
|  | 	if err := s.setupTLSListener(opts); err != nil { | ||||||
|  | 		return nil, fmt.Errorf("error setting up TLS listener: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return s, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // server is an implementation of the Server interface.
 | ||||||
|  | type server struct { | ||||||
|  | 	handler http.Handler | ||||||
|  | 
 | ||||||
|  | 	listener    net.Listener | ||||||
|  | 	tlsListener net.Listener | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // setupListener sets the server listener if the HTTP server is enabled.
 | ||||||
|  | // The HTTP server can be disabled by setting the BindAddress to "-" or by
 | ||||||
|  | // leaving it empty.
 | ||||||
|  | func (s *server) setupListener(opts Opts) error { | ||||||
|  | 	if opts.BindAddress == "" || opts.BindAddress == "-" { | ||||||
|  | 		// No HTTP listener required
 | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	networkType := getNetworkScheme(opts.BindAddress) | ||||||
|  | 	listenAddr := getListenAddress(opts.BindAddress) | ||||||
|  | 
 | ||||||
|  | 	listener, err := net.Listen(networkType, listenAddr) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return fmt.Errorf("listen (%s, %s) failed: %v", networkType, listenAddr, err) | ||||||
|  | 	} | ||||||
|  | 	s.listener = listener | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // setupTLSListener sets the server TLS listener if the HTTPS server is enabled.
 | ||||||
|  | // The HTTPS server can be disabled by setting the SecureBindAddress to "-" or by
 | ||||||
|  | // leaving it empty.
 | ||||||
|  | func (s *server) setupTLSListener(opts Opts) error { | ||||||
|  | 	if opts.SecureBindAddress == "" || opts.SecureBindAddress == "-" { | ||||||
|  | 		// No HTTPS listener required
 | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	config := &tls.Config{ | ||||||
|  | 		MinVersion: tls.VersionTLS12, | ||||||
|  | 		MaxVersion: tls.VersionTLS13, | ||||||
|  | 		NextProtos: []string{"http/1.1"}, | ||||||
|  | 	} | ||||||
|  | 	if opts.TLS == nil { | ||||||
|  | 		return errors.New("no TLS config provided") | ||||||
|  | 	} | ||||||
|  | 	cert, err := getCertificate(opts.TLS) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return fmt.Errorf("could not load certificate: %v", err) | ||||||
|  | 	} | ||||||
|  | 	config.Certificates = []tls.Certificate{cert} | ||||||
|  | 
 | ||||||
|  | 	listenAddr := getListenAddress(opts.SecureBindAddress) | ||||||
|  | 
 | ||||||
|  | 	listener, err := net.Listen("tcp", listenAddr) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return fmt.Errorf("listen (%s) failed: %v", listenAddr, err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	s.tlsListener = tls.NewListener(tcpKeepAliveListener{listener.(*net.TCPListener)}, config) | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Start starts the HTTP and HTTPS server if applicable.
 | ||||||
|  | // It will block until the context is cancelled.
 | ||||||
|  | // If any errors occur, only the first error will be returned.
 | ||||||
|  | func (s *server) Start(ctx context.Context) error { | ||||||
|  | 	g, groupCtx := errgroup.WithContext(ctx) | ||||||
|  | 
 | ||||||
|  | 	if s.listener != nil { | ||||||
|  | 		g.Go(func() error { | ||||||
|  | 			if err := s.startServer(groupCtx, s.listener); err != nil { | ||||||
|  | 				return fmt.Errorf("error starting insecure server: %v", err) | ||||||
|  | 			} | ||||||
|  | 			return nil | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if s.tlsListener != nil { | ||||||
|  | 		g.Go(func() error { | ||||||
|  | 			if err := s.startServer(groupCtx, s.tlsListener); err != nil { | ||||||
|  | 				return fmt.Errorf("error starting secure server: %v", err) | ||||||
|  | 			} | ||||||
|  | 			return nil | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return g.Wait() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // startServer creates and starts a new server with the given listener.
 | ||||||
|  | // When the given context is cancelled the server will be shutdown.
 | ||||||
|  | // If any errors occur, only the first error will be returned.
 | ||||||
|  | func (s *server) startServer(ctx context.Context, listener net.Listener) error { | ||||||
|  | 	srv := &http.Server{Handler: s.handler} | ||||||
|  | 	g, groupCtx := errgroup.WithContext(ctx) | ||||||
|  | 
 | ||||||
|  | 	g.Go(func() error { | ||||||
|  | 		<-groupCtx.Done() | ||||||
|  | 
 | ||||||
|  | 		if err := srv.Shutdown(context.Background()); err != nil { | ||||||
|  | 			return fmt.Errorf("error shutting down server: %v", err) | ||||||
|  | 		} | ||||||
|  | 		return nil | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	g.Go(func() error { | ||||||
|  | 		if err := srv.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) { | ||||||
|  | 			return fmt.Errorf("could not start server: %v", err) | ||||||
|  | 		} | ||||||
|  | 		return nil | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	return g.Wait() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // getNetworkScheme gets the scheme for the HTTP server.
 | ||||||
|  | func getNetworkScheme(addr string) string { | ||||||
|  | 	var scheme string | ||||||
|  | 	i := strings.Index(addr, "://") | ||||||
|  | 	if i > -1 { | ||||||
|  | 		scheme = addr[0:i] | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	switch scheme { | ||||||
|  | 	case "", "http": | ||||||
|  | 		return "tcp" | ||||||
|  | 	default: | ||||||
|  | 		return scheme | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // getListenAddress gets the address for the HTTP server.
 | ||||||
|  | func getListenAddress(addr string) string { | ||||||
|  | 	slice := strings.SplitN(addr, "//", 2) | ||||||
|  | 	return slice[len(slice)-1] | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // getCertificate loads the certificate data from the TLS config.
 | ||||||
|  | func getCertificate(opts *options.TLS) (tls.Certificate, error) { | ||||||
|  | 	keyData, err := getSecretValue(opts.Key) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return tls.Certificate{}, fmt.Errorf("could not load key data: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	certData, err := getSecretValue(opts.Cert) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return tls.Certificate{}, fmt.Errorf("could not load cert data: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	cert, err := tls.X509KeyPair(certData, keyData) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return tls.Certificate{}, fmt.Errorf("could not parse certificate data: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return cert, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // getSecretValue wraps util.GetSecretValue so that we can return an error if no
 | ||||||
|  | // source is provided.
 | ||||||
|  | func getSecretValue(src *options.SecretSource) ([]byte, error) { | ||||||
|  | 	if src == nil { | ||||||
|  | 		return nil, errors.New("no configuration provided") | ||||||
|  | 	} | ||||||
|  | 	return util.GetSecretValue(src) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
 | ||||||
|  | // connections. It's used by so that dead TCP connections (e.g. closing laptop
 | ||||||
|  | // mid-download) eventually go away.
 | ||||||
|  | type tcpKeepAliveListener struct { | ||||||
|  | 	*net.TCPListener | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Accept implements the TCPListener interface.
 | ||||||
|  | // It sets the keep alive period to 3 minutes for each connection.
 | ||||||
|  | func (ln tcpKeepAliveListener) Accept() (net.Conn, error) { | ||||||
|  | 	tc, err := ln.AcceptTCP() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	err = tc.SetKeepAlive(true) | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.Errorf("Error setting Keep-Alive: %v", err) | ||||||
|  | 	} | ||||||
|  | 	err = tc.SetKeepAlivePeriod(3 * time.Minute) | ||||||
|  | 	if err != nil { | ||||||
|  | 		logger.Printf("Error setting Keep-Alive period: %v", err) | ||||||
|  | 	} | ||||||
|  | 	return tc, nil | ||||||
|  | } | ||||||
|  | @ -0,0 +1,35 @@ | ||||||
|  | package http | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 
 | ||||||
|  | 	"golang.org/x/sync/errgroup" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // NewServerGroup creates a new Server to start and gracefully stop a collection
 | ||||||
|  | // of Servers.
 | ||||||
|  | func NewServerGroup(servers ...Server) Server { | ||||||
|  | 	return &serverGroup{ | ||||||
|  | 		servers: servers, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // serverGroup manages the starting and graceful shutdown of a collection of
 | ||||||
|  | // servers.
 | ||||||
|  | type serverGroup struct { | ||||||
|  | 	servers []Server | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Start runs the servers in the server group.
 | ||||||
|  | func (s *serverGroup) Start(ctx context.Context) error { | ||||||
|  | 	g, groupCtx := errgroup.WithContext(ctx) | ||||||
|  | 
 | ||||||
|  | 	for _, server := range s.servers { | ||||||
|  | 		srv := server | ||||||
|  | 		g.Go(func() error { | ||||||
|  | 			return srv.Start(groupCtx) | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return g.Wait() | ||||||
|  | } | ||||||
|  | @ -0,0 +1,102 @@ | ||||||
|  | package http | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"errors" | ||||||
|  | 
 | ||||||
|  | 	. "github.com/onsi/ginkgo" | ||||||
|  | 	. "github.com/onsi/gomega" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | var _ = Describe("Server Group", func() { | ||||||
|  | 	var m1, m2, m3 *mockServer | ||||||
|  | 	var ctx context.Context | ||||||
|  | 	var cancel context.CancelFunc | ||||||
|  | 	var group Server | ||||||
|  | 
 | ||||||
|  | 	BeforeEach(func() { | ||||||
|  | 		ctx, cancel = context.WithCancel(context.Background()) | ||||||
|  | 
 | ||||||
|  | 		m1 = newMockServer() | ||||||
|  | 		m2 = newMockServer() | ||||||
|  | 		m3 = newMockServer() | ||||||
|  | 		group = NewServerGroup(m1, m2, m3) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	AfterEach(func() { | ||||||
|  | 		cancel() | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	It("starts each server in the group", func() { | ||||||
|  | 		go func() { | ||||||
|  | 			defer GinkgoRecover() | ||||||
|  | 			Expect(group.Start(ctx)).To(Succeed()) | ||||||
|  | 		}() | ||||||
|  | 
 | ||||||
|  | 		Eventually(m1.started).Should(BeClosed(), "mock server 1 not started") | ||||||
|  | 		Eventually(m2.started).Should(BeClosed(), "mock server 2 not started") | ||||||
|  | 		Eventually(m3.started).Should(BeClosed(), "mock server 3 not started") | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	It("stop each server in the group when the context is cancelled", func() { | ||||||
|  | 		go func() { | ||||||
|  | 			defer GinkgoRecover() | ||||||
|  | 			Expect(group.Start(ctx)).To(Succeed()) | ||||||
|  | 		}() | ||||||
|  | 
 | ||||||
|  | 		cancel() | ||||||
|  | 		Eventually(m1.stopped).Should(BeClosed(), "mock server 1 not stopped") | ||||||
|  | 		Eventually(m2.stopped).Should(BeClosed(), "mock server 2 not stopped") | ||||||
|  | 		Eventually(m3.stopped).Should(BeClosed(), "mock server 3 not stopped") | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	It("stop each server in the group when the an error occurs", func() { | ||||||
|  | 		err := errors.New("server error") | ||||||
|  | 		go func() { | ||||||
|  | 			defer GinkgoRecover() | ||||||
|  | 			Expect(group.Start(ctx)).To(MatchError(err)) | ||||||
|  | 		}() | ||||||
|  | 
 | ||||||
|  | 		m2.errors <- err | ||||||
|  | 		Eventually(m1.stopped).Should(BeClosed(), "mock server 1 not stopped") | ||||||
|  | 		Eventually(m2.stopped).Should(BeClosed(), "mock server 2 not stopped") | ||||||
|  | 		Eventually(m3.stopped).Should(BeClosed(), "mock server 3 not stopped") | ||||||
|  | 	}) | ||||||
|  | }) | ||||||
|  | 
 | ||||||
|  | // mockServer is used to test the server group can start
 | ||||||
|  | // and stop multiple servers simultaneously.
 | ||||||
|  | type mockServer struct { | ||||||
|  | 	started     chan struct{} | ||||||
|  | 	startClosed bool | ||||||
|  | 	stopped     chan struct{} | ||||||
|  | 	stopClosed  bool | ||||||
|  | 	errors      chan error | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func newMockServer() *mockServer { | ||||||
|  | 	return &mockServer{ | ||||||
|  | 		started: make(chan struct{}), | ||||||
|  | 		stopped: make(chan struct{}), | ||||||
|  | 		errors:  make(chan error), | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m *mockServer) Start(ctx context.Context) error { | ||||||
|  | 	if !m.startClosed { | ||||||
|  | 		close(m.started) | ||||||
|  | 		m.startClosed = true | ||||||
|  | 	} | ||||||
|  | 	defer func() { | ||||||
|  | 		if !m.stopClosed { | ||||||
|  | 			close(m.stopped) | ||||||
|  | 			m.stopClosed = true | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  | 	select { | ||||||
|  | 	case <-ctx.Done(): | ||||||
|  | 		return nil | ||||||
|  | 	case err := <-m.errors: | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | @ -0,0 +1,472 @@ | ||||||
|  | package http | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"io/ioutil" | ||||||
|  | 	"net/http" | ||||||
|  | 
 | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||||
|  | 	. "github.com/onsi/ginkgo" | ||||||
|  | 	. "github.com/onsi/ginkgo/extensions/table" | ||||||
|  | 	. "github.com/onsi/gomega" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | const hello = "Hello World!" | ||||||
|  | 
 | ||||||
|  | var _ = Describe("Server", func() { | ||||||
|  | 	handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||||
|  | 		rw.Write([]byte(hello)) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("NewServer", func() { | ||||||
|  | 		type newServerTableInput struct { | ||||||
|  | 			opts               Opts | ||||||
|  | 			expectedErr        error | ||||||
|  | 			expectHTTPListener bool | ||||||
|  | 			expectTLSListener  bool | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		DescribeTable("When creating the new server from the options", func(in *newServerTableInput) { | ||||||
|  | 			srv, err := NewServer(in.opts) | ||||||
|  | 			if in.expectedErr != nil { | ||||||
|  | 				Expect(err).To(MatchError(ContainSubstring(in.expectedErr.Error()))) | ||||||
|  | 				Expect(srv).To(BeNil()) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 			s, ok := srv.(*server) | ||||||
|  | 			Expect(ok).To(BeTrue()) | ||||||
|  | 
 | ||||||
|  | 			Expect(s.listener != nil).To(Equal(in.expectHTTPListener)) | ||||||
|  | 			if in.expectHTTPListener { | ||||||
|  | 				Expect(s.listener.Close()).To(Succeed()) | ||||||
|  | 			} | ||||||
|  | 			Expect(s.tlsListener != nil).To(Equal(in.expectTLSListener)) | ||||||
|  | 			if in.expectTLSListener { | ||||||
|  | 				Expect(s.tlsListener.Close()).To(Succeed()) | ||||||
|  | 			} | ||||||
|  | 		}, | ||||||
|  | 			Entry("with a valid http bind address", &newServerTableInput{ | ||||||
|  | 				opts: Opts{ | ||||||
|  | 					Handler:     handler, | ||||||
|  | 					BindAddress: "127.0.0.1:0", | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:        nil, | ||||||
|  | 				expectHTTPListener: true, | ||||||
|  | 				expectTLSListener:  false, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with a valid https bind address, with no TLS config", &newServerTableInput{ | ||||||
|  | 				opts: Opts{ | ||||||
|  | 					Handler:           handler, | ||||||
|  | 					SecureBindAddress: "127.0.0.1:0", | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:        errors.New("error setting up TLS listener: no TLS config provided"), | ||||||
|  | 				expectHTTPListener: false, | ||||||
|  | 				expectTLSListener:  false, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with a valid https bind address, and valid TLS config", &newServerTableInput{ | ||||||
|  | 				opts: Opts{ | ||||||
|  | 					Handler:           handler, | ||||||
|  | 					SecureBindAddress: "127.0.0.1:0", | ||||||
|  | 					TLS: &options.TLS{ | ||||||
|  | 						Key:  &keyDataSource, | ||||||
|  | 						Cert: &certDataSource, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:        nil, | ||||||
|  | 				expectHTTPListener: false, | ||||||
|  | 				expectTLSListener:  true, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with a both a valid http and valid https bind address, and valid TLS config", &newServerTableInput{ | ||||||
|  | 				opts: Opts{ | ||||||
|  | 					Handler:           handler, | ||||||
|  | 					BindAddress:       "127.0.0.1:0", | ||||||
|  | 					SecureBindAddress: "127.0.0.1:0", | ||||||
|  | 					TLS: &options.TLS{ | ||||||
|  | 						Key:  &keyDataSource, | ||||||
|  | 						Cert: &certDataSource, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:        nil, | ||||||
|  | 				expectHTTPListener: true, | ||||||
|  | 				expectTLSListener:  true, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with a \"-\" for the bind address", &newServerTableInput{ | ||||||
|  | 				opts: Opts{ | ||||||
|  | 					Handler:     handler, | ||||||
|  | 					BindAddress: "-", | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:        nil, | ||||||
|  | 				expectHTTPListener: false, | ||||||
|  | 				expectTLSListener:  false, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with a \"-\" for the secure bind address", &newServerTableInput{ | ||||||
|  | 				opts: Opts{ | ||||||
|  | 					Handler:           handler, | ||||||
|  | 					SecureBindAddress: "-", | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:        nil, | ||||||
|  | 				expectHTTPListener: false, | ||||||
|  | 				expectTLSListener:  false, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with an invalid bind address scheme", &newServerTableInput{ | ||||||
|  | 				opts: Opts{ | ||||||
|  | 					Handler:     handler, | ||||||
|  | 					BindAddress: "invalid://127.0.0.1:0", | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:        errors.New("error setting up listener: listen (invalid, 127.0.0.1:0) failed: listen invalid: unknown network invalid"), | ||||||
|  | 				expectHTTPListener: false, | ||||||
|  | 				expectTLSListener:  false, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with an invalid secure bind address scheme", &newServerTableInput{ | ||||||
|  | 				opts: Opts{ | ||||||
|  | 					Handler:           handler, | ||||||
|  | 					SecureBindAddress: "invalid://127.0.0.1:0", | ||||||
|  | 					TLS: &options.TLS{ | ||||||
|  | 						Key:  &keyDataSource, | ||||||
|  | 						Cert: &certDataSource, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:        nil, | ||||||
|  | 				expectHTTPListener: false, | ||||||
|  | 				expectTLSListener:  true, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with an invalid bind address port", &newServerTableInput{ | ||||||
|  | 				opts: Opts{ | ||||||
|  | 					Handler:     handler, | ||||||
|  | 					BindAddress: "127.0.0.1:a", | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:        errors.New("error setting up listener: listen (tcp, 127.0.0.1:a) failed: listen tcp: lookup tcp/a: "), | ||||||
|  | 				expectHTTPListener: false, | ||||||
|  | 				expectTLSListener:  false, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with an invalid secure bind address port", &newServerTableInput{ | ||||||
|  | 				opts: Opts{ | ||||||
|  | 					Handler:           handler, | ||||||
|  | 					SecureBindAddress: "127.0.0.1:a", | ||||||
|  | 					TLS: &options.TLS{ | ||||||
|  | 						Key:  &keyDataSource, | ||||||
|  | 						Cert: &certDataSource, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:        errors.New("error setting up TLS listener: listen (127.0.0.1:a) failed: listen tcp: lookup tcp/a: "), | ||||||
|  | 				expectHTTPListener: false, | ||||||
|  | 				expectTLSListener:  false, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with an invalid TLS key", &newServerTableInput{ | ||||||
|  | 				opts: Opts{ | ||||||
|  | 					Handler:           handler, | ||||||
|  | 					SecureBindAddress: "127.0.0.1:0", | ||||||
|  | 					TLS: &options.TLS{ | ||||||
|  | 						Key: &options.SecretSource{ | ||||||
|  | 							Value: []byte("invalid"), | ||||||
|  | 						}, | ||||||
|  | 						Cert: &certDataSource, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:        errors.New("error setting up TLS listener: could not load certificate: could not parse certificate data: tls: failed to find any PEM data in key input"), | ||||||
|  | 				expectHTTPListener: false, | ||||||
|  | 				expectTLSListener:  false, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with an invalid TLS cert", &newServerTableInput{ | ||||||
|  | 				opts: Opts{ | ||||||
|  | 					Handler:           handler, | ||||||
|  | 					SecureBindAddress: "127.0.0.1:0", | ||||||
|  | 					TLS: &options.TLS{ | ||||||
|  | 						Key: &keyDataSource, | ||||||
|  | 						Cert: &options.SecretSource{ | ||||||
|  | 							Value: []byte("invalid"), | ||||||
|  | 						}, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:        errors.New("error setting up TLS listener: could not load certificate: could not parse certificate data: tls: failed to find any PEM data in certificate input"), | ||||||
|  | 				expectHTTPListener: false, | ||||||
|  | 				expectTLSListener:  false, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with no TLS key", &newServerTableInput{ | ||||||
|  | 				opts: Opts{ | ||||||
|  | 					Handler:           handler, | ||||||
|  | 					SecureBindAddress: "127.0.0.1:0", | ||||||
|  | 					TLS: &options.TLS{ | ||||||
|  | 						Cert: &certDataSource, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:        errors.New("error setting up TLS listener: could not load certificate: could not load key data: no configuration provided"), | ||||||
|  | 				expectHTTPListener: false, | ||||||
|  | 				expectTLSListener:  false, | ||||||
|  | 			}), | ||||||
|  | 			Entry("with no TLS cert", &newServerTableInput{ | ||||||
|  | 				opts: Opts{ | ||||||
|  | 					Handler:           handler, | ||||||
|  | 					SecureBindAddress: "127.0.0.1:0", | ||||||
|  | 					TLS: &options.TLS{ | ||||||
|  | 						Key: &keyDataSource, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:        errors.New("error setting up TLS listener: could not load certificate: could not load cert data: no configuration provided"), | ||||||
|  | 				expectHTTPListener: false, | ||||||
|  | 				expectTLSListener:  false, | ||||||
|  | 			}), | ||||||
|  | 			Entry("when the bind address is prefixed with the http scheme", &newServerTableInput{ | ||||||
|  | 				opts: Opts{ | ||||||
|  | 					Handler:     handler, | ||||||
|  | 					BindAddress: "http://127.0.0.1:0", | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:        nil, | ||||||
|  | 				expectHTTPListener: true, | ||||||
|  | 				expectTLSListener:  false, | ||||||
|  | 			}), | ||||||
|  | 			Entry("when the secure bind address is prefixed with the https scheme", &newServerTableInput{ | ||||||
|  | 				opts: Opts{ | ||||||
|  | 					Handler:           handler, | ||||||
|  | 					SecureBindAddress: "https://127.0.0.1:0", | ||||||
|  | 					TLS: &options.TLS{ | ||||||
|  | 						Key:  &keyDataSource, | ||||||
|  | 						Cert: &certDataSource, | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:        nil, | ||||||
|  | 				expectHTTPListener: false, | ||||||
|  | 				expectTLSListener:  true, | ||||||
|  | 			}), | ||||||
|  | 		) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("Start", func() { | ||||||
|  | 		var srv Server | ||||||
|  | 		var ctx context.Context | ||||||
|  | 		var cancel context.CancelFunc | ||||||
|  | 
 | ||||||
|  | 		BeforeEach(func() { | ||||||
|  | 			ctx, cancel = context.WithCancel(context.Background()) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		AfterEach(func() { | ||||||
|  | 			cancel() | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		Context("with an http server", func() { | ||||||
|  | 			var listenAddr string | ||||||
|  | 
 | ||||||
|  | 			BeforeEach(func() { | ||||||
|  | 				var err error | ||||||
|  | 				srv, err = NewServer(Opts{ | ||||||
|  | 					Handler:     handler, | ||||||
|  | 					BindAddress: "127.0.0.1:0", | ||||||
|  | 				}) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 				s, ok := srv.(*server) | ||||||
|  | 				Expect(ok).To(BeTrue()) | ||||||
|  | 
 | ||||||
|  | 				listenAddr = fmt.Sprintf("http://%s/", s.listener.Addr().String()) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			It("Starts the server and serves the handler", func() { | ||||||
|  | 				go func() { | ||||||
|  | 					defer GinkgoRecover() | ||||||
|  | 					Expect(srv.Start(ctx)).To(Succeed()) | ||||||
|  | 				}() | ||||||
|  | 
 | ||||||
|  | 				resp, err := client.Get(listenAddr) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				Expect(resp.StatusCode).To(Equal(http.StatusOK)) | ||||||
|  | 
 | ||||||
|  | 				body, err := ioutil.ReadAll(resp.Body) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				Expect(string(body)).To(Equal(hello)) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			It("Stops the server when the context is cancelled", func() { | ||||||
|  | 				go func() { | ||||||
|  | 					defer GinkgoRecover() | ||||||
|  | 					Expect(srv.Start(ctx)).To(Succeed()) | ||||||
|  | 				}() | ||||||
|  | 
 | ||||||
|  | 				_, err := client.Get(listenAddr) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 				cancel() | ||||||
|  | 
 | ||||||
|  | 				Eventually(func() error { | ||||||
|  | 					_, err := client.Get(listenAddr) | ||||||
|  | 					return err | ||||||
|  | 				}).Should(HaveOccurred()) | ||||||
|  | 			}) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		Context("with an https server", func() { | ||||||
|  | 			var secureListenAddr string | ||||||
|  | 
 | ||||||
|  | 			BeforeEach(func() { | ||||||
|  | 				var err error | ||||||
|  | 				srv, err = NewServer(Opts{ | ||||||
|  | 					Handler:           handler, | ||||||
|  | 					SecureBindAddress: "127.0.0.1:0", | ||||||
|  | 					TLS: &options.TLS{ | ||||||
|  | 						Key:  &keyDataSource, | ||||||
|  | 						Cert: &certDataSource, | ||||||
|  | 					}, | ||||||
|  | 				}) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 				s, ok := srv.(*server) | ||||||
|  | 				Expect(ok).To(BeTrue()) | ||||||
|  | 
 | ||||||
|  | 				secureListenAddr = fmt.Sprintf("https://%s/", s.tlsListener.Addr().String()) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			It("Starts the server and serves the handler", func() { | ||||||
|  | 				go func() { | ||||||
|  | 					defer GinkgoRecover() | ||||||
|  | 					Expect(srv.Start(ctx)).To(Succeed()) | ||||||
|  | 				}() | ||||||
|  | 
 | ||||||
|  | 				resp, err := client.Get(secureListenAddr) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				Expect(resp.StatusCode).To(Equal(http.StatusOK)) | ||||||
|  | 
 | ||||||
|  | 				body, err := ioutil.ReadAll(resp.Body) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				Expect(string(body)).To(Equal(hello)) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			It("Stops the server when the context is cancelled", func() { | ||||||
|  | 				go func() { | ||||||
|  | 					defer GinkgoRecover() | ||||||
|  | 					Expect(srv.Start(ctx)).To(Succeed()) | ||||||
|  | 				}() | ||||||
|  | 
 | ||||||
|  | 				_, err := client.Get(secureListenAddr) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 				cancel() | ||||||
|  | 
 | ||||||
|  | 				Eventually(func() error { | ||||||
|  | 					_, err := client.Get(secureListenAddr) | ||||||
|  | 					return err | ||||||
|  | 				}).Should(HaveOccurred()) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			It("Serves the certificate provided", func() { | ||||||
|  | 				go func() { | ||||||
|  | 					defer GinkgoRecover() | ||||||
|  | 					Expect(srv.Start(ctx)).To(Succeed()) | ||||||
|  | 				}() | ||||||
|  | 
 | ||||||
|  | 				resp, err := client.Get(secureListenAddr) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				Expect(resp.StatusCode).To(Equal(http.StatusOK)) | ||||||
|  | 
 | ||||||
|  | 				Expect(resp.TLS.VerifiedChains).Should(HaveLen(1)) | ||||||
|  | 				Expect(resp.TLS.VerifiedChains[0]).Should(HaveLen(1)) | ||||||
|  | 				Expect(resp.TLS.VerifiedChains[0][0].Raw).Should(Equal(certData)) | ||||||
|  | 			}) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		Context("with both an http and an https server", func() { | ||||||
|  | 			var listenAddr, secureListenAddr string | ||||||
|  | 
 | ||||||
|  | 			BeforeEach(func() { | ||||||
|  | 				var err error | ||||||
|  | 				srv, err = NewServer(Opts{ | ||||||
|  | 					Handler:           handler, | ||||||
|  | 					BindAddress:       "127.0.0.1:0", | ||||||
|  | 					SecureBindAddress: "127.0.0.1:0", | ||||||
|  | 					TLS: &options.TLS{ | ||||||
|  | 						Key:  &keyDataSource, | ||||||
|  | 						Cert: &certDataSource, | ||||||
|  | 					}, | ||||||
|  | 				}) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 				s, ok := srv.(*server) | ||||||
|  | 				Expect(ok).To(BeTrue()) | ||||||
|  | 
 | ||||||
|  | 				listenAddr = fmt.Sprintf("http://%s/", s.listener.Addr().String()) | ||||||
|  | 				secureListenAddr = fmt.Sprintf("https://%s/", s.tlsListener.Addr().String()) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			It("Starts the server and serves the handler on http", func() { | ||||||
|  | 				go func() { | ||||||
|  | 					defer GinkgoRecover() | ||||||
|  | 					Expect(srv.Start(ctx)).To(Succeed()) | ||||||
|  | 				}() | ||||||
|  | 
 | ||||||
|  | 				resp, err := client.Get(listenAddr) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				Expect(resp.StatusCode).To(Equal(http.StatusOK)) | ||||||
|  | 
 | ||||||
|  | 				body, err := ioutil.ReadAll(resp.Body) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				Expect(string(body)).To(Equal(hello)) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			It("Starts the server and serves the handler on https", func() { | ||||||
|  | 				go func() { | ||||||
|  | 					defer GinkgoRecover() | ||||||
|  | 					Expect(srv.Start(ctx)).To(Succeed()) | ||||||
|  | 				}() | ||||||
|  | 
 | ||||||
|  | 				resp, err := client.Get(secureListenAddr) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				Expect(resp.StatusCode).To(Equal(http.StatusOK)) | ||||||
|  | 
 | ||||||
|  | 				body, err := ioutil.ReadAll(resp.Body) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				Expect(string(body)).To(Equal(hello)) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			It("Stops both servers when the context is cancelled", func() { | ||||||
|  | 				go func() { | ||||||
|  | 					defer GinkgoRecover() | ||||||
|  | 					Expect(srv.Start(ctx)).To(Succeed()) | ||||||
|  | 				}() | ||||||
|  | 
 | ||||||
|  | 				_, err := client.Get(listenAddr) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				_, err = client.Get(secureListenAddr) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 				cancel() | ||||||
|  | 
 | ||||||
|  | 				Eventually(func() error { | ||||||
|  | 					_, err := client.Get(listenAddr) | ||||||
|  | 					return err | ||||||
|  | 				}).Should(HaveOccurred()) | ||||||
|  | 				Eventually(func() error { | ||||||
|  | 					_, err := client.Get(secureListenAddr) | ||||||
|  | 					return err | ||||||
|  | 				}).Should(HaveOccurred()) | ||||||
|  | 			}) | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("getNetworkScheme", func() { | ||||||
|  | 		DescribeTable("should return the scheme", func(in, expected string) { | ||||||
|  | 			Expect(getNetworkScheme(in)).To(Equal(expected)) | ||||||
|  | 		}, | ||||||
|  | 			Entry("with no scheme", "127.0.0.1:0", "tcp"), | ||||||
|  | 			Entry("with a tcp scheme", "tcp://127.0.0.1:0", "tcp"), | ||||||
|  | 			Entry("with a http scheme", "http://192.168.0.1:1", "tcp"), | ||||||
|  | 			Entry("with a unix scheme", "unix://172.168.16.2:2", "unix"), | ||||||
|  | 			Entry("with a random scheme", "random://10.10.10.10:10", "random"), | ||||||
|  | 		) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("getListenAddress", func() { | ||||||
|  | 		DescribeTable("should remove the scheme", func(in, expected string) { | ||||||
|  | 			Expect(getListenAddress(in)).To(Equal(expected)) | ||||||
|  | 		}, | ||||||
|  | 			Entry("with no scheme", "127.0.0.1:0", "127.0.0.1:0"), | ||||||
|  | 			Entry("with a tcp scheme", "tcp://127.0.0.1:0", "127.0.0.1:0"), | ||||||
|  | 			Entry("with a http scheme", "http://192.168.0.1:1", "192.168.0.1:1"), | ||||||
|  | 			Entry("with a unix scheme", "unix://172.168.16.2:2", "172.168.16.2:2"), | ||||||
|  | 			Entry("with a random scheme", "random://10.10.10.10:10", "10.10.10.10:10"), | ||||||
|  | 		) | ||||||
|  | 	}) | ||||||
|  | }) | ||||||
		Loading…
	
		Reference in New Issue