Merge pull request #1047 from oauth2-proxy/http-server
Refactor HTTP Server and add ServerGroup to handle graceful shutdown of multiple servers
This commit is contained in:
		
						commit
						6894738d97
					
				|  | @ -8,6 +8,7 @@ | |||
| 
 | ||||
| ## Changes since v7.0.1 | ||||
| 
 | ||||
| - [#1047](https://github.com/oauth2-proxy/oauth2-proxy/pull/1047) Refactor HTTP Server and add ServerGroup to handle graceful shutdown of multiple servers (@JoelSpeed) | ||||
| - [#1070](https://github.com/oauth2-proxy/oauth2-proxy/pull/1070) Refactor logging middleware to middleware package (@NickMeves) | ||||
| - [#1064](https://github.com/oauth2-proxy/oauth2-proxy/pull/1064) Add support for setting groups on session when using basic auth (@stefansedich) | ||||
| - [#1056](https://github.com/oauth2-proxy/oauth2-proxy/pull/1056) Add option for custom logos on the sign in page (@JoelSpeed) | ||||
|  |  | |||
|  | @ -117,6 +117,8 @@ They may change between releases without notice. | |||
| | `upstreams` | _[Upstreams](#upstreams)_ | Upstreams is used to configure upstream servers.<br/>Once a user is authenticated, requests to the server will be proxied to<br/>these upstream servers based on the path mappings defined in this list. | | ||||
| | `injectRequestHeaders` | _[[]Header](#header)_ | InjectRequestHeaders is used to configure headers that should be added<br/>to requests to upstream servers.<br/>Headers may source values from either the authenticated user's session<br/>or from a static secret value. | | ||||
| | `injectResponseHeaders` | _[[]Header](#header)_ | InjectResponseHeaders is used to configure headers that should be added<br/>to responses from the proxy.<br/>This is typically used when using the proxy as an external authentication<br/>provider in conjunction with another proxy such as NGINX and its<br/>auth_request module.<br/>Headers may source values from either the authenticated user's session<br/>or from a static secret value. | | ||||
| | `server` | _[Server](#server)_ | Server is used to configure the HTTP(S) server for the proxy application.<br/>You may choose to run both HTTP and HTTPS servers simultaneously.<br/>This can be done by setting the BindAddress and the SecureBindAddress simultaneously.<br/>To use the secure server you must configure a TLS certificate and key. | | ||||
| | `metricsServer` | _[Server](#server)_ | MetricsServer is used to configure the HTTP(S) server for metrics.<br/>You may choose to run both HTTP and HTTPS servers simultaneously.<br/>This can be done by setting the BindAddress and the SecureBindAddress simultaneously.<br/>To use the secure server you must configure a TLS certificate and key. | | ||||
| 
 | ||||
| ### ClaimSource | ||||
| 
 | ||||
|  | @ -172,7 +174,7 @@ make up the header value | |||
| 
 | ||||
| ### SecretSource | ||||
| 
 | ||||
| (**Appears on:** [ClaimSource](#claimsource), [HeaderValue](#headervalue)) | ||||
| (**Appears on:** [ClaimSource](#claimsource), [HeaderValue](#headervalue), [TLS](#tls)) | ||||
| 
 | ||||
| SecretSource references an individual secret value. | ||||
| Only one source within the struct should be defined at any time. | ||||
|  | @ -183,6 +185,29 @@ Only one source within the struct should be defined at any time. | |||
| | `fromEnv` | _string_ | FromEnv expects the name of an environment variable. | | ||||
| | `fromFile` | _string_ | FromFile expects a path to a file containing the secret value. | | ||||
| 
 | ||||
| ### Server | ||||
| 
 | ||||
| (**Appears on:** [AlphaOptions](#alphaoptions)) | ||||
| 
 | ||||
| Server represents the configuration for an HTTP(S) server | ||||
| 
 | ||||
| | Field | Type | Description | | ||||
| | ----- | ---- | ----------- | | ||||
| | `BindAddress` | _string_ | BindAddress is the the address on which to serve traffic.<br/>Leave blank or set to "-" to disable. | | ||||
| | `SecureBindAddress` | _string_ | SecureBindAddress is the the address on which to serve secure traffic.<br/>Leave blank or set to "-" to disable. | | ||||
| | `TLS` | _[TLS](#tls)_ | TLS contains the information for loading the certificate and key for the<br/>secure traffic. | | ||||
| 
 | ||||
| ### TLS | ||||
| 
 | ||||
| (**Appears on:** [Server](#server)) | ||||
| 
 | ||||
| TLS contains the information for loading a TLS certifcate and key. | ||||
| 
 | ||||
| | Field | Type | Description | | ||||
| | ----- | ---- | ----------- | | ||||
| | `Key` | _[SecretSource](#secretsource)_ | Key is the the TLS key data to use.<br/>Typically this will come from a file. | | ||||
| | `Cert` | _[SecretSource](#secretsource)_ | Cert is the TLS certificate data to use.<br/>Typically this will come from a file. | | ||||
| 
 | ||||
| ### Upstream | ||||
| 
 | ||||
| (**Appears on:** [Upstreams](#upstreams)) | ||||
|  |  | |||
							
								
								
									
										1
									
								
								go.mod
								
								
								
								
							
							
						
						
									
										1
									
								
								go.mod
								
								
								
								
							|  | @ -30,6 +30,7 @@ require ( | |||
| 	golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 | ||||
| 	golang.org/x/net v0.0.0-20200707034311-ab3426394381 | ||||
| 	golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d | ||||
| 	golang.org/x/sync v0.0.0-20201207232520-09787c993a3a | ||||
| 	google.golang.org/api v0.20.0 | ||||
| 	gopkg.in/natefinch/lumberjack.v2 v2.0.0 | ||||
| 	gopkg.in/square/go-jose.v2 v2.4.1 | ||||
|  |  | |||
							
								
								
									
										3
									
								
								go.sum
								
								
								
								
							
							
						
						
									
										3
									
								
								go.sum
								
								
								
								
							|  | @ -506,7 +506,10 @@ golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJ | |||
| golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||
| golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||
| golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||
| golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY= | ||||
| golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||
| golang.org/x/sync v0.0.0-20201207232520-09787c993a3a h1:DcqTD9SDLc+1P/r1EmRBwnVsrOwW+kk2vWf9n+1sGhs= | ||||
| golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||
| golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | ||||
| golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | ||||
| golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | ||||
|  |  | |||
							
								
								
									
										136
									
								
								http.go
								
								
								
								
							
							
						
						
									
										136
									
								
								http.go
								
								
								
								
							|  | @ -1,136 +0,0 @@ | |||
| package main | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"crypto/tls" | ||||
| 	"errors" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||
| ) | ||||
| 
 | ||||
| // Server represents an HTTP server
 | ||||
| type Server struct { | ||||
| 	Handler http.Handler | ||||
| 	Opts    *options.Options | ||||
| 	stop    chan struct{} // channel for waiting shutdown
 | ||||
| } | ||||
| 
 | ||||
| // ListenAndServe will serve traffic on HTTP or HTTPS depending on TLS options
 | ||||
| func (s *Server) ListenAndServe() { | ||||
| 	if s.Opts.TLSKeyFile != "" || s.Opts.TLSCertFile != "" { | ||||
| 		s.ServeHTTPS() | ||||
| 	} else { | ||||
| 		s.ServeHTTP() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // ServeHTTP constructs a net.Listener and starts handling HTTP requests
 | ||||
| func (s *Server) ServeHTTP() { | ||||
| 	HTTPAddress := s.Opts.HTTPAddress | ||||
| 	var scheme string | ||||
| 
 | ||||
| 	i := strings.Index(HTTPAddress, "://") | ||||
| 	if i > -1 { | ||||
| 		scheme = HTTPAddress[0:i] | ||||
| 	} | ||||
| 
 | ||||
| 	var networkType string | ||||
| 	switch scheme { | ||||
| 	case "", "http": | ||||
| 		networkType = "tcp" | ||||
| 	default: | ||||
| 		networkType = scheme | ||||
| 	} | ||||
| 
 | ||||
| 	slice := strings.SplitN(HTTPAddress, "//", 2) | ||||
| 	listenAddr := slice[len(slice)-1] | ||||
| 
 | ||||
| 	listener, err := net.Listen(networkType, listenAddr) | ||||
| 	if err != nil { | ||||
| 		logger.Fatalf("FATAL: listen (%s, %s) failed - %s", networkType, listenAddr, err) | ||||
| 	} | ||||
| 	logger.Printf("HTTP: listening on %s", listenAddr) | ||||
| 	s.serve(listener) | ||||
| 	logger.Printf("HTTP: closing %s", listener.Addr()) | ||||
| } | ||||
| 
 | ||||
| // ServeHTTPS constructs a net.Listener and starts handling HTTPS requests
 | ||||
| func (s *Server) ServeHTTPS() { | ||||
| 	addr := s.Opts.HTTPSAddress | ||||
| 	config := &tls.Config{ | ||||
| 		MinVersion: tls.VersionTLS12, | ||||
| 		MaxVersion: tls.VersionTLS13, | ||||
| 	} | ||||
| 	if config.NextProtos == nil { | ||||
| 		config.NextProtos = []string{"http/1.1"} | ||||
| 	} | ||||
| 
 | ||||
| 	var err error | ||||
| 	config.Certificates = make([]tls.Certificate, 1) | ||||
| 	config.Certificates[0], err = tls.LoadX509KeyPair(s.Opts.TLSCertFile, s.Opts.TLSKeyFile) | ||||
| 	if err != nil { | ||||
| 		logger.Fatalf("FATAL: loading tls config (%s, %s) failed - %s", s.Opts.TLSCertFile, s.Opts.TLSKeyFile, err) | ||||
| 	} | ||||
| 
 | ||||
| 	ln, err := net.Listen("tcp", addr) | ||||
| 	if err != nil { | ||||
| 		logger.Fatalf("FATAL: listen (%s) failed - %s", addr, err) | ||||
| 	} | ||||
| 	logger.Printf("HTTPS: listening on %s", ln.Addr()) | ||||
| 
 | ||||
| 	tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config) | ||||
| 	s.serve(tlsListener) | ||||
| 	logger.Printf("HTTPS: closing %s", tlsListener.Addr()) | ||||
| } | ||||
| 
 | ||||
| func (s *Server) serve(listener net.Listener) { | ||||
| 	srv := &http.Server{Handler: s.Handler} | ||||
| 
 | ||||
| 	// See https://golang.org/pkg/net/http/#Server.Shutdown
 | ||||
| 	idleConnsClosed := make(chan struct{}) | ||||
| 	go func() { | ||||
| 		<-s.stop // wait notification for stopping server
 | ||||
| 
 | ||||
| 		// We received an interrupt signal, shut down.
 | ||||
| 		if err := srv.Shutdown(context.Background()); err != nil { | ||||
| 			// Error from closing listeners, or context timeout:
 | ||||
| 			logger.Printf("HTTP server Shutdown: %v", err) | ||||
| 		} | ||||
| 		close(idleConnsClosed) | ||||
| 	}() | ||||
| 
 | ||||
| 	err := srv.Serve(listener) | ||||
| 	if err != nil && !errors.Is(err, http.ErrServerClosed) { | ||||
| 		logger.Errorf("ERROR: http.Serve() - %s", err) | ||||
| 	} | ||||
| 	<-idleConnsClosed | ||||
| } | ||||
| 
 | ||||
| // tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
 | ||||
| // connections. It's used by ListenAndServe and ListenAndServeTLS so
 | ||||
| // dead TCP connections (e.g. closing laptop mid-download) eventually
 | ||||
| // go away.
 | ||||
| type tcpKeepAliveListener struct { | ||||
| 	*net.TCPListener | ||||
| } | ||||
| 
 | ||||
| func (ln tcpKeepAliveListener) Accept() (net.Conn, error) { | ||||
| 	tc, err := ln.AcceptTCP() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	err = tc.SetKeepAlive(true) | ||||
| 	if err != nil { | ||||
| 		logger.Printf("Error setting Keep-Alive: %v", err) | ||||
| 	} | ||||
| 	err = tc.SetKeepAlivePeriod(3 * time.Minute) | ||||
| 	if err != nil { | ||||
| 		logger.Printf("Error setting Keep-Alive period: %v", err) | ||||
| 	} | ||||
| 	return tc, nil | ||||
| } | ||||
							
								
								
									
										39
									
								
								http_test.go
								
								
								
								
							
							
						
						
									
										39
									
								
								http_test.go
								
								
								
								
							|  | @ -1,39 +0,0 @@ | |||
| package main | ||||
| 
 | ||||
| import ( | ||||
| 	"net/http" | ||||
| 	"sync" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
| 
 | ||||
| func TestGracefulShutdown(t *testing.T) { | ||||
| 	opts := options.NewOptions() | ||||
| 	stop := make(chan struct{}, 1) | ||||
| 	srv := Server{Handler: http.DefaultServeMux, Opts: opts, stop: stop} | ||||
| 	var wg sync.WaitGroup | ||||
| 	wg.Add(1) | ||||
| 	go func() { | ||||
| 		defer wg.Done() | ||||
| 		srv.ServeHTTP() | ||||
| 	}() | ||||
| 
 | ||||
| 	stop <- struct{}{} // emulate catching signals
 | ||||
| 
 | ||||
| 	// An idiomatic for sync.WaitGroup with timeout
 | ||||
| 	c := make(chan struct{}) | ||||
| 	go func() { | ||||
| 		defer close(c) | ||||
| 		wg.Wait() | ||||
| 	}() | ||||
| 	select { | ||||
| 	case <-c: | ||||
| 	case <-time.After(1 * time.Second): | ||||
| 		t.Fatal("Server should return gracefully but timeout has occurred") | ||||
| 	} | ||||
| 
 | ||||
| 	assert.Len(t, stop, 0) // check if stop chan is empty
 | ||||
| } | ||||
							
								
								
									
										54
									
								
								main.go
								
								
								
								
							
							
						
						
									
										54
									
								
								main.go
								
								
								
								
							|  | @ -1,20 +1,15 @@ | |||
| package main | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"math/rand" | ||||
| 	"net/http" | ||||
| 	"os" | ||||
| 	"os/signal" | ||||
| 	"runtime" | ||||
| 	"syscall" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/ghodss/yaml" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/validation" | ||||
| 	"github.com/spf13/pflag" | ||||
| ) | ||||
|  | @ -67,54 +62,9 @@ func main() { | |||
| 
 | ||||
| 	rand.Seed(time.Now().UnixNano()) | ||||
| 
 | ||||
| 	oauthProxyStop := make(chan struct{}, 1) | ||||
| 	metricsStop := startMetricsServer(opts.MetricsAddress, oauthProxyStop) | ||||
| 
 | ||||
| 	s := &Server{ | ||||
| 		Handler: oauthproxy, | ||||
| 		Opts:    opts, | ||||
| 		stop:    oauthProxyStop, | ||||
| 	if err := oauthproxy.Start(); err != nil { | ||||
| 		logger.Fatalf("ERROR: Failed to start OAuth2 Proxy: %v", err) | ||||
| 	} | ||||
| 	// Observe signals in background goroutine.
 | ||||
| 	go func() { | ||||
| 		sigint := make(chan os.Signal, 1) | ||||
| 		signal.Notify(sigint, os.Interrupt, syscall.SIGTERM) | ||||
| 		<-sigint | ||||
| 		s.stop <- struct{}{} // notify having caught signal stop oauthproxy
 | ||||
| 		close(metricsStop)   // and the metrics endpoint
 | ||||
| 	}() | ||||
| 	s.ListenAndServe() | ||||
| } | ||||
| 
 | ||||
| // startMetricsServer will start the metrics server on the specified address.
 | ||||
| // It always return a channel to signal stop even when it does not run.
 | ||||
| func startMetricsServer(address string, oauthProxyStop chan struct{}) chan struct{} { | ||||
| 	stop := make(chan struct{}, 1) | ||||
| 
 | ||||
| 	// Attempt to setup the metrics endpoint if we have an address
 | ||||
| 	if address != "" { | ||||
| 		s := &http.Server{Addr: address, Handler: middleware.DefaultMetricsHandler} | ||||
| 		go func() { | ||||
| 			// ListenAndServe always returns a non-nil error. After Shutdown or
 | ||||
| 			// Close, the returned error is ErrServerClosed
 | ||||
| 			if err := s.ListenAndServe(); err != http.ErrServerClosed { | ||||
| 				logger.Println(err) | ||||
| 				// Stop the metrics shutdown go routine
 | ||||
| 				close(stop) | ||||
| 				// Stop the oauthproxy server, we have encounter an unexpected error
 | ||||
| 				close(oauthProxyStop) | ||||
| 			} | ||||
| 		}() | ||||
| 
 | ||||
| 		go func() { | ||||
| 			<-stop | ||||
| 			if err := s.Shutdown(context.Background()); err != nil { | ||||
| 				logger.Print(err) | ||||
| 			} | ||||
| 		}() | ||||
| 	} | ||||
| 
 | ||||
| 	return stop | ||||
| } | ||||
| 
 | ||||
| // loadConfiguration will load in the user's configuration.
 | ||||
|  |  | |||
|  | @ -15,6 +15,7 @@ import ( | |||
| 
 | ||||
| var _ = Describe("Configuration Loading Suite", func() { | ||||
| 	const testLegacyConfig = ` | ||||
| http_address="127.0.0.1:4180" | ||||
| upstreams="http://httpbin" | ||||
| set_basic_auth="true" | ||||
| basic_auth_password="super-secret-password" | ||||
|  | @ -54,10 +55,11 @@ injectResponseHeaders: | |||
|     prefix: "Basic " | ||||
|     basicAuthPassword: | ||||
|       value: c3VwZXItc2VjcmV0LXBhc3N3b3Jk | ||||
| server: | ||||
|   bindAddress: "127.0.0.1:4180" | ||||
| ` | ||||
| 
 | ||||
| 	const testCoreConfig = ` | ||||
| http_address="0.0.0.0:4180" | ||||
| cookie_secret="OQINaROshtE9TcZkNAm-5Zs2Pv3xaWytBmc5W7sPX7w=" | ||||
| provider="oidc" | ||||
| email_domains="example.com" | ||||
|  | @ -82,7 +84,6 @@ redirect_url="http://localhost:4180/oauth2/callback" | |||
| 		opts, err := options.NewLegacyOptions().ToOptions() | ||||
| 		Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 		opts.HTTPAddress = "0.0.0.0:4180" | ||||
| 		opts.Cookie.Secret = "OQINaROshtE9TcZkNAm-5Zs2Pv3xaWytBmc5W7sPX7w=" | ||||
| 		opts.ProviderType = "oidc" | ||||
| 		opts.EmailDomains = []string{"example.com"} | ||||
|  | @ -203,7 +204,7 @@ redirect_url="http://localhost:4180/oauth2/callback" | |||
| 			configContent:      testCoreConfig, | ||||
| 			alphaConfigContent: testAlphaConfig + ":", | ||||
| 			expectedOptions:    func() *options.Options { return nil }, | ||||
| 			expectedErr:        errors.New("failed to load alpha options: error unmarshalling config: error converting YAML to JSON: yaml: line 34: did not find expected key"), | ||||
| 			expectedErr:        errors.New("failed to load alpha options: error unmarshalling config: error converting YAML to JSON: yaml: line 36: did not find expected key"), | ||||
| 		}), | ||||
| 		Entry("with alpha configuration and bad core configuration", loadConfigurationTableInput{ | ||||
| 			configContent:      testCoreConfig + "unknown_field=\"something\"", | ||||
|  |  | |||
|  | @ -8,8 +8,11 @@ import ( | |||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"os" | ||||
| 	"os/signal" | ||||
| 	"regexp" | ||||
| 	"strings" | ||||
| 	"syscall" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/justinas/alice" | ||||
|  | @ -21,6 +24,7 @@ import ( | |||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption" | ||||
| 	proxyhttp "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/http" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware" | ||||
|  | @ -102,6 +106,7 @@ type OAuthProxy struct { | |||
| 	headersChain alice.Chain | ||||
| 	preAuthChain alice.Chain | ||||
| 	pageWriter   pagewriter.Writer | ||||
| 	server       proxyhttp.Server | ||||
| } | ||||
| 
 | ||||
| // NewOAuthProxy creates a new instance of OAuthProxy from the options provided
 | ||||
|  | @ -184,7 +189,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | |||
| 		return nil, fmt.Errorf("could not build headers chain: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	return &OAuthProxy{ | ||||
| 	p := &OAuthProxy{ | ||||
| 		CookieName:     opts.Cookie.Name, | ||||
| 		CSRFCookieName: fmt.Sprintf("%v_%v", opts.Cookie.Name, "csrf"), | ||||
| 		CookieSeed:     opts.Cookie.Secret, | ||||
|  | @ -223,7 +228,60 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | |||
| 		headersChain:       headersChain, | ||||
| 		preAuthChain:       preAuthChain, | ||||
| 		pageWriter:         pageWriter, | ||||
| 	}, nil | ||||
| 	} | ||||
| 
 | ||||
| 	if err := p.setupServer(opts); err != nil { | ||||
| 		return nil, fmt.Errorf("error setting up server: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	return p, nil | ||||
| } | ||||
| 
 | ||||
| func (p *OAuthProxy) Start() error { | ||||
| 	if p.server == nil { | ||||
| 		// We have to call setupServer before Start is called.
 | ||||
| 		// If this doesn't happen it's a programming error.
 | ||||
| 		panic("server has not been initialised") | ||||
| 	} | ||||
| 
 | ||||
| 	ctx, cancel := context.WithCancel(context.Background()) | ||||
| 
 | ||||
| 	// Observe signals in background goroutine.
 | ||||
| 	go func() { | ||||
| 		sigint := make(chan os.Signal, 1) | ||||
| 		signal.Notify(sigint, os.Interrupt, syscall.SIGTERM) | ||||
| 		<-sigint | ||||
| 		cancel() // cancel the context
 | ||||
| 	}() | ||||
| 
 | ||||
| 	return p.server.Start(ctx) | ||||
| } | ||||
| 
 | ||||
| func (p *OAuthProxy) setupServer(opts *options.Options) error { | ||||
| 	serverOpts := proxyhttp.Opts{ | ||||
| 		Handler:           p, | ||||
| 		BindAddress:       opts.Server.BindAddress, | ||||
| 		SecureBindAddress: opts.Server.SecureBindAddress, | ||||
| 		TLS:               opts.Server.TLS, | ||||
| 	} | ||||
| 
 | ||||
| 	appServer, err := proxyhttp.NewServer(serverOpts) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("could not build app server: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	metricsServer, err := proxyhttp.NewServer(proxyhttp.Opts{ | ||||
| 		Handler:           middleware.DefaultMetricsHandler, | ||||
| 		BindAddress:       opts.MetricsServer.BindAddress, | ||||
| 		SecureBindAddress: opts.MetricsServer.BindAddress, | ||||
| 		TLS:               opts.MetricsServer.TLS, | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("could not build metrics server: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	p.server = proxyhttp.NewServerGroup(appServer, metricsServer) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // buildPreAuthChain constructs a chain that should process every request before
 | ||||
|  | @ -233,9 +291,9 @@ func buildPreAuthChain(opts *options.Options) (alice.Chain, error) { | |||
| 	chain := alice.New(middleware.NewScope(opts.ReverseProxy)) | ||||
| 
 | ||||
| 	if opts.ForceHTTPS { | ||||
| 		_, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress) | ||||
| 		_, httpsPort, err := net.SplitHostPort(opts.Server.SecureBindAddress) | ||||
| 		if err != nil { | ||||
| 			return alice.Chain{}, fmt.Errorf("invalid HTTPS address %q: %v", opts.HTTPAddress, err) | ||||
| 			return alice.Chain{}, fmt.Errorf("invalid HTTPS address %q: %v", opts.Server.SecureBindAddress, err) | ||||
| 		} | ||||
| 		chain = chain.Append(middleware.NewRedirectToHTTPS(httpsPort)) | ||||
| 	} | ||||
|  |  | |||
|  | @ -2341,6 +2341,7 @@ func baseTestOptions() *options.Options { | |||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	return opts | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -28,6 +28,18 @@ type AlphaOptions struct { | |||
| 	// Headers may source values from either the authenticated user's session
 | ||||
| 	// or from a static secret value.
 | ||||
| 	InjectResponseHeaders []Header `json:"injectResponseHeaders,omitempty"` | ||||
| 
 | ||||
| 	// Server is used to configure the HTTP(S) server for the proxy application.
 | ||||
| 	// You may choose to run both HTTP and HTTPS servers simultaneously.
 | ||||
| 	// This can be done by setting the BindAddress and the SecureBindAddress simultaneously.
 | ||||
| 	// To use the secure server you must configure a TLS certificate and key.
 | ||||
| 	Server Server `json:"server,omitempty"` | ||||
| 
 | ||||
| 	// MetricsServer is used to configure the HTTP(S) server for metrics.
 | ||||
| 	// You may choose to run both HTTP and HTTPS servers simultaneously.
 | ||||
| 	// This can be done by setting the BindAddress and the SecureBindAddress simultaneously.
 | ||||
| 	// To use the secure server you must configure a TLS certificate and key.
 | ||||
| 	MetricsServer Server `json:"metricsServer,omitempty"` | ||||
| } | ||||
| 
 | ||||
| // MergeInto replaces alpha options in the Options struct with the values
 | ||||
|  | @ -36,6 +48,8 @@ func (a *AlphaOptions) MergeInto(opts *Options) { | |||
| 	opts.UpstreamServers = a.Upstreams | ||||
| 	opts.InjectRequestHeaders = a.InjectRequestHeaders | ||||
| 	opts.InjectResponseHeaders = a.InjectResponseHeaders | ||||
| 	opts.Server = a.Server | ||||
| 	opts.MetricsServer = a.MetricsServer | ||||
| } | ||||
| 
 | ||||
| // ExtractFrom populates the fields in the AlphaOptions with the values from
 | ||||
|  | @ -44,4 +58,6 @@ func (a *AlphaOptions) ExtractFrom(opts *Options) { | |||
| 	a.Upstreams = opts.UpstreamServers | ||||
| 	a.InjectRequestHeaders = opts.InjectRequestHeaders | ||||
| 	a.InjectResponseHeaders = opts.InjectResponseHeaders | ||||
| 	a.Server = opts.Server | ||||
| 	a.MetricsServer = opts.MetricsServer | ||||
| } | ||||
|  |  | |||
|  | @ -18,6 +18,9 @@ type LegacyOptions struct { | |||
| 	// Legacy options for injecting request/response headers
 | ||||
| 	LegacyHeaders LegacyHeaders `cfg:",squash"` | ||||
| 
 | ||||
| 	// Legacy options for the server address and TLS
 | ||||
| 	LegacyServer LegacyServer `cfg:",squash"` | ||||
| 
 | ||||
| 	Options Options `cfg:",squash"` | ||||
| } | ||||
| 
 | ||||
|  | @ -35,6 +38,11 @@ func NewLegacyOptions() *LegacyOptions { | |||
| 			SkipAuthStripHeaders: true, | ||||
| 		}, | ||||
| 
 | ||||
| 		LegacyServer: LegacyServer{ | ||||
| 			HTTPAddress:  "127.0.0.1:4180", | ||||
| 			HTTPSAddress: ":443", | ||||
| 		}, | ||||
| 
 | ||||
| 		Options: *NewOptions(), | ||||
| 	} | ||||
| } | ||||
|  | @ -44,6 +52,7 @@ func NewLegacyFlagSet() *pflag.FlagSet { | |||
| 
 | ||||
| 	flagSet.AddFlagSet(legacyUpstreamsFlagSet()) | ||||
| 	flagSet.AddFlagSet(legacyHeadersFlagSet()) | ||||
| 	flagSet.AddFlagSet(legacyServerFlagset()) | ||||
| 
 | ||||
| 	return flagSet | ||||
| } | ||||
|  | @ -56,6 +65,8 @@ func (l *LegacyOptions) ToOptions() (*Options, error) { | |||
| 	l.Options.UpstreamServers = upstreams | ||||
| 
 | ||||
| 	l.Options.InjectRequestHeaders, l.Options.InjectResponseHeaders = l.LegacyHeaders.convert() | ||||
| 	l.Options.Server, l.Options.MetricsServer = l.LegacyServer.convert() | ||||
| 
 | ||||
| 	return &l.Options, nil | ||||
| } | ||||
| 
 | ||||
|  | @ -403,3 +414,69 @@ func getXAuthRequestAccessTokenHeader() Header { | |||
| 		}, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type LegacyServer struct { | ||||
| 	MetricsAddress       string `flag:"metrics-address" cfg:"metrics_address"` | ||||
| 	MetricsSecureAddress string `flag:"metrics-secure-address" cfg:"metrics_address"` | ||||
| 	MetricsTLSCertFile   string `flag:"metrics-tls-cert-file" cfg:"tls_cert_file"` | ||||
| 	MetricsTLSKeyFile    string `flag:"metrics-tls-key-file" cfg:"tls_key_file"` | ||||
| 	HTTPAddress          string `flag:"http-address" cfg:"http_address"` | ||||
| 	HTTPSAddress         string `flag:"https-address" cfg:"https_address"` | ||||
| 	TLSCertFile          string `flag:"tls-cert-file" cfg:"tls_cert_file"` | ||||
| 	TLSKeyFile           string `flag:"tls-key-file" cfg:"tls_key_file"` | ||||
| } | ||||
| 
 | ||||
| func legacyServerFlagset() *pflag.FlagSet { | ||||
| 	flagSet := pflag.NewFlagSet("server", pflag.ExitOnError) | ||||
| 
 | ||||
| 	flagSet.String("metrics-address", "", "the address /metrics will be served on (e.g. \":9100\")") | ||||
| 	flagSet.String("metrics-secure-address", "", "the address /metrics will be served on for HTTPS clients (e.g. \":9100\")") | ||||
| 	flagSet.String("metrics-tls-cert-file", "", "path to certificate file for secure metrics server") | ||||
| 	flagSet.String("metrics-tls-key-file", "", "path to private key file for secure metrics server") | ||||
| 	flagSet.String("http-address", "127.0.0.1:4180", "[http://]<addr>:<port> or unix://<path> to listen on for HTTP clients") | ||||
| 	flagSet.String("https-address", ":443", "<addr>:<port> to listen on for HTTPS clients") | ||||
| 	flagSet.String("tls-cert-file", "", "path to certificate file") | ||||
| 	flagSet.String("tls-key-file", "", "path to private key file") | ||||
| 
 | ||||
| 	return flagSet | ||||
| } | ||||
| 
 | ||||
| func (l LegacyServer) convert() (Server, Server) { | ||||
| 	appServer := Server{ | ||||
| 		BindAddress:       l.HTTPAddress, | ||||
| 		SecureBindAddress: l.HTTPSAddress, | ||||
| 	} | ||||
| 	if l.TLSKeyFile != "" || l.TLSCertFile != "" { | ||||
| 		appServer.TLS = &TLS{ | ||||
| 			Key: &SecretSource{ | ||||
| 				FromFile: l.TLSKeyFile, | ||||
| 			}, | ||||
| 			Cert: &SecretSource{ | ||||
| 				FromFile: l.TLSCertFile, | ||||
| 			}, | ||||
| 		} | ||||
| 		// Preserve backwards compatibility, only run one server
 | ||||
| 		appServer.BindAddress = "" | ||||
| 	} else { | ||||
| 		// Disable the HTTPS server if there's no certificates.
 | ||||
| 		// This preserves backwards compatibility.
 | ||||
| 		appServer.SecureBindAddress = "" | ||||
| 	} | ||||
| 
 | ||||
| 	metricsServer := Server{ | ||||
| 		BindAddress:       l.MetricsAddress, | ||||
| 		SecureBindAddress: l.MetricsSecureAddress, | ||||
| 	} | ||||
| 	if l.MetricsTLSKeyFile != "" || l.MetricsTLSCertFile != "" { | ||||
| 		metricsServer.TLS = &TLS{ | ||||
| 			Key: &SecretSource{ | ||||
| 				FromFile: l.MetricsTLSKeyFile, | ||||
| 			}, | ||||
| 			Cert: &SecretSource{ | ||||
| 				FromFile: l.MetricsTLSCertFile, | ||||
| 			}, | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return appServer, metricsServer | ||||
| } | ||||
|  |  | |||
|  | @ -106,6 +106,10 @@ var _ = Describe("Legacy Options", func() { | |||
| 
 | ||||
| 			opts.InjectResponseHeaders = []Header{} | ||||
| 
 | ||||
| 			opts.Server = Server{ | ||||
| 				BindAddress: "127.0.0.1:4180", | ||||
| 			} | ||||
| 
 | ||||
| 			converted, err := legacyOpts.ToOptions() | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
| 			Expect(converted).To(Equal(opts)) | ||||
|  | @ -759,4 +763,93 @@ var _ = Describe("Legacy Options", func() { | |||
| 			}), | ||||
| 		) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("Legacy Servers", func() { | ||||
| 		type legacyServersTableInput struct { | ||||
| 			legacyServer          LegacyServer | ||||
| 			expectedAppServer     Server | ||||
| 			expectedMetricsServer Server | ||||
| 		} | ||||
| 
 | ||||
| 		const ( | ||||
| 			insecureAddr        = "127.0.0.1:8080" | ||||
| 			insecureMetricsAddr = ":9090" | ||||
| 			secureAddr          = ":443" | ||||
| 			secureMetricsAddr   = ":9443" | ||||
| 			crtPath             = "tls.crt" | ||||
| 			keyPath             = "tls.key" | ||||
| 		) | ||||
| 
 | ||||
| 		var tlsConfig = &TLS{ | ||||
| 			Cert: &SecretSource{ | ||||
| 				FromFile: crtPath, | ||||
| 			}, | ||||
| 			Key: &SecretSource{ | ||||
| 				FromFile: keyPath, | ||||
| 			}, | ||||
| 		} | ||||
| 
 | ||||
| 		DescribeTable("should convert to app and metrics servers", | ||||
| 			func(in legacyServersTableInput) { | ||||
| 				appServer, metricsServer := in.legacyServer.convert() | ||||
| 				Expect(appServer).To(Equal(in.expectedAppServer)) | ||||
| 				Expect(metricsServer).To(Equal(in.expectedMetricsServer)) | ||||
| 			}, | ||||
| 			Entry("with default options only starts app HTTP server", legacyServersTableInput{ | ||||
| 				legacyServer: LegacyServer{ | ||||
| 					HTTPAddress:  insecureAddr, | ||||
| 					HTTPSAddress: secureAddr, | ||||
| 				}, | ||||
| 				expectedAppServer: Server{ | ||||
| 					BindAddress: insecureAddr, | ||||
| 				}, | ||||
| 			}), | ||||
| 			Entry("with TLS options specified only starts app HTTPS server", legacyServersTableInput{ | ||||
| 				legacyServer: LegacyServer{ | ||||
| 					HTTPAddress:  insecureAddr, | ||||
| 					HTTPSAddress: secureAddr, | ||||
| 					TLSKeyFile:   keyPath, | ||||
| 					TLSCertFile:  crtPath, | ||||
| 				}, | ||||
| 				expectedAppServer: Server{ | ||||
| 					SecureBindAddress: secureAddr, | ||||
| 					TLS:               tlsConfig, | ||||
| 				}, | ||||
| 			}), | ||||
| 			Entry("with metrics HTTP and HTTPS addresses", legacyServersTableInput{ | ||||
| 				legacyServer: LegacyServer{ | ||||
| 					HTTPAddress:          insecureAddr, | ||||
| 					HTTPSAddress:         secureAddr, | ||||
| 					MetricsAddress:       insecureMetricsAddr, | ||||
| 					MetricsSecureAddress: secureMetricsAddr, | ||||
| 				}, | ||||
| 				expectedAppServer: Server{ | ||||
| 					BindAddress: insecureAddr, | ||||
| 				}, | ||||
| 				expectedMetricsServer: Server{ | ||||
| 					BindAddress:       insecureMetricsAddr, | ||||
| 					SecureBindAddress: secureMetricsAddr, | ||||
| 				}, | ||||
| 			}), | ||||
| 			Entry("with metrics HTTPS and tls cert/key", legacyServersTableInput{ | ||||
| 				legacyServer: LegacyServer{ | ||||
| 					HTTPAddress:          insecureAddr, | ||||
| 					HTTPSAddress:         secureAddr, | ||||
| 					MetricsAddress:       insecureMetricsAddr, | ||||
| 					MetricsSecureAddress: secureMetricsAddr, | ||||
| 					MetricsTLSKeyFile:    keyPath, | ||||
| 					MetricsTLSCertFile:   crtPath, | ||||
| 				}, | ||||
| 				expectedAppServer: Server{ | ||||
| 					BindAddress: insecureAddr, | ||||
| 				}, | ||||
| 				expectedMetricsServer: Server{ | ||||
| 					BindAddress:       insecureMetricsAddr, | ||||
| 					SecureBindAddress: secureMetricsAddr, | ||||
| 					TLS:               tlsConfig, | ||||
| 				}, | ||||
| 			}), | ||||
| 		) | ||||
| 
 | ||||
| 	}) | ||||
| }) | ||||
|  |  | |||
|  | @ -22,9 +22,6 @@ type Options struct { | |||
| 	ProxyPrefix        string   `flag:"proxy-prefix" cfg:"proxy_prefix"` | ||||
| 	PingPath           string   `flag:"ping-path" cfg:"ping_path"` | ||||
| 	PingUserAgent      string   `flag:"ping-user-agent" cfg:"ping_user_agent"` | ||||
| 	MetricsAddress     string   `flag:"metrics-address" cfg:"metrics_address"` | ||||
| 	HTTPAddress        string   `flag:"http-address" cfg:"http_address"` | ||||
| 	HTTPSAddress       string   `flag:"https-address" cfg:"https_address"` | ||||
| 	ReverseProxy       bool     `flag:"reverse-proxy" cfg:"reverse_proxy"` | ||||
| 	RealClientIPHeader string   `flag:"real-client-ip-header" cfg:"real_client_ip_header"` | ||||
| 	TrustedIPs         []string `flag:"trusted-ip" cfg:"trusted_ips"` | ||||
|  | @ -33,8 +30,6 @@ type Options struct { | |||
| 	ClientID           string   `flag:"client-id" cfg:"client_id"` | ||||
| 	ClientSecret       string   `flag:"client-secret" cfg:"client_secret"` | ||||
| 	ClientSecretFile   string   `flag:"client-secret-file" cfg:"client_secret_file"` | ||||
| 	TLSCertFile        string   `flag:"tls-cert-file" cfg:"tls_cert_file"` | ||||
| 	TLSKeyFile         string   `flag:"tls-key-file" cfg:"tls_key_file"` | ||||
| 
 | ||||
| 	AuthenticatedEmailsFile  string   `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"` | ||||
| 	KeycloakGroups           []string `flag:"keycloak-group" cfg:"keycloak_groups"` | ||||
|  | @ -68,6 +63,9 @@ type Options struct { | |||
| 	InjectRequestHeaders  []Header `cfg:",internal"` | ||||
| 	InjectResponseHeaders []Header `cfg:",internal"` | ||||
| 
 | ||||
| 	Server        Server `cfg:",internal"` | ||||
| 	MetricsServer Server `cfg:",internal"` | ||||
| 
 | ||||
| 	SkipAuthRegex         []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"` | ||||
| 	SkipAuthRoutes        []string `flag:"skip-auth-route" cfg:"skip_auth_routes"` | ||||
| 	SkipJwtBearerTokens   bool     `flag:"skip-jwt-bearer-tokens" cfg:"skip_jwt_bearer_tokens"` | ||||
|  | @ -136,10 +134,7 @@ func NewOptions() *Options { | |||
| 	return &Options{ | ||||
| 		ProxyPrefix:                      "/oauth2", | ||||
| 		ProviderType:                     "google", | ||||
| 		MetricsAddress:                   "", | ||||
| 		PingPath:                         "/ping", | ||||
| 		HTTPAddress:                      "127.0.0.1:4180", | ||||
| 		HTTPSAddress:                     ":443", | ||||
| 		RealClientIPHeader:               "X-Real-IP", | ||||
| 		ForceHTTPS:                       false, | ||||
| 		Cookie:                           cookieDefaults(), | ||||
|  | @ -162,14 +157,10 @@ func NewOptions() *Options { | |||
| func NewFlagSet() *pflag.FlagSet { | ||||
| 	flagSet := pflag.NewFlagSet("oauth2-proxy", pflag.ExitOnError) | ||||
| 
 | ||||
| 	flagSet.String("http-address", "127.0.0.1:4180", "[http://]<addr>:<port> or unix://<path> to listen on for HTTP clients") | ||||
| 	flagSet.String("https-address", ":443", "<addr>:<port> to listen on for HTTPS clients") | ||||
| 	flagSet.Bool("reverse-proxy", false, "are we running behind a reverse proxy, controls whether headers like X-Real-Ip are accepted") | ||||
| 	flagSet.String("real-client-ip-header", "X-Real-IP", "Header used to determine the real IP of the client (one of: X-Forwarded-For, X-Real-IP, or X-ProxyUser-IP)") | ||||
| 	flagSet.StringSlice("trusted-ip", []string{}, "list of IPs or CIDR ranges to allow to bypass authentication. WARNING: trusting by IP has inherent security flaws, read the configuration documentation for more information.") | ||||
| 	flagSet.Bool("force-https", false, "force HTTPS redirect for HTTP requests") | ||||
| 	flagSet.String("tls-cert-file", "", "path to certificate file") | ||||
| 	flagSet.String("tls-key-file", "", "path to private key file") | ||||
| 	flagSet.String("redirect-url", "", "the OAuth Redirect URL. ie: \"https://internalapp.yourcompany.com/oauth2/callback\"") | ||||
| 	flagSet.StringSlice("skip-auth-regex", []string{}, "(DEPRECATED for --skip-auth-route) bypass authentication for requests path's that match (may be given multiple times)") | ||||
| 	flagSet.StringSlice("skip-auth-route", []string{}, "bypass authentication for requests that match the method & path. Format: method=path_regex OR path_regex alone for all methods") | ||||
|  | @ -204,7 +195,6 @@ func NewFlagSet() *pflag.FlagSet { | |||
| 	flagSet.String("proxy-prefix", "/oauth2", "the url root path that this proxy should be nested under (e.g. /<oauth2>/sign_in)") | ||||
| 	flagSet.String("ping-path", "/ping", "the ping endpoint that can be used for basic health checks") | ||||
| 	flagSet.String("ping-user-agent", "", "special User-Agent that will be used for basic health checks") | ||||
| 	flagSet.String("metrics-address", "", "the address /metrics will be served on (e.g. \":9100\")") | ||||
| 	flagSet.String("session-store-type", "cookie", "the session storage provider to use") | ||||
| 	flagSet.Bool("session-cookie-minimal", false, "strip OAuth tokens from cookie session stores if they aren't needed (cookie session store only)") | ||||
| 	flagSet.String("redis-connection-url", "", "URL of redis server for redis session storage (eg: redis://HOST[:PORT])") | ||||
|  |  | |||
|  | @ -0,0 +1,27 @@ | |||
| package options | ||||
| 
 | ||||
| // Server represents the configuration for an HTTP(S) server
 | ||||
| type Server struct { | ||||
| 	// BindAddress is the the address on which to serve traffic.
 | ||||
| 	// Leave blank or set to "-" to disable.
 | ||||
| 	BindAddress string | ||||
| 
 | ||||
| 	// SecureBindAddress is the the address on which to serve secure traffic.
 | ||||
| 	// Leave blank or set to "-" to disable.
 | ||||
| 	SecureBindAddress string | ||||
| 
 | ||||
| 	// TLS contains the information for loading the certificate and key for the
 | ||||
| 	// secure traffic.
 | ||||
| 	TLS *TLS | ||||
| } | ||||
| 
 | ||||
| // TLS contains the information for loading a TLS certifcate and key.
 | ||||
| type TLS struct { | ||||
| 	// Key is the the TLS key data to use.
 | ||||
| 	// Typically this will come from a file.
 | ||||
| 	Key *SecretSource | ||||
| 
 | ||||
| 	// Cert is the TLS certificate data to use.
 | ||||
| 	// Typically this will come from a file.
 | ||||
| 	Cert *SecretSource | ||||
| } | ||||
|  | @ -0,0 +1,88 @@ | |||
| package http | ||||
| 
 | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"crypto/rand" | ||||
| 	"crypto/rsa" | ||||
| 	"crypto/tls" | ||||
| 	"crypto/x509" | ||||
| 	"crypto/x509/pkix" | ||||
| 	"encoding/pem" | ||||
| 	"math/big" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/gomega" | ||||
| ) | ||||
| 
 | ||||
| var certData []byte | ||||
| var certDataSource, keyDataSource options.SecretSource | ||||
| var client *http.Client | ||||
| 
 | ||||
| func TestHTTPSuite(t *testing.T) { | ||||
| 	logger.SetOutput(GinkgoWriter) | ||||
| 	logger.SetErrOutput(GinkgoWriter) | ||||
| 
 | ||||
| 	RegisterFailHandler(Fail) | ||||
| 	RunSpecs(t, "HTTP") | ||||
| } | ||||
| 
 | ||||
| var _ = BeforeSuite(func() { | ||||
| 	By("Generating a self-signed cert for TLS tests", func() { | ||||
| 		priv, err := rsa.GenerateKey(rand.Reader, 2048) | ||||
| 		Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 		keyOut := bytes.NewBuffer(nil) | ||||
| 		privBytes, err := x509.MarshalPKCS8PrivateKey(priv) | ||||
| 		Expect(err).ToNot(HaveOccurred()) | ||||
| 		Expect(pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes})).To(Succeed()) | ||||
| 		keyDataSource.Value = keyOut.Bytes() | ||||
| 
 | ||||
| 		serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) | ||||
| 		serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) | ||||
| 		Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 		template := x509.Certificate{ | ||||
| 			SerialNumber: serialNumber, | ||||
| 			Subject: pkix.Name{ | ||||
| 				Organization: []string{"OAuth2 Proxy Test Suite"}, | ||||
| 			}, | ||||
| 			NotBefore:   time.Now(), | ||||
| 			NotAfter:    time.Now().Add(time.Hour), | ||||
| 			IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, | ||||
| 			KeyUsage:    x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, | ||||
| 			ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, | ||||
| 		} | ||||
| 
 | ||||
| 		certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) | ||||
| 		Expect(err).ToNot(HaveOccurred()) | ||||
| 		certData = certBytes | ||||
| 
 | ||||
| 		certOut := bytes.NewBuffer(nil) | ||||
| 		Expect(pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes})).To(Succeed()) | ||||
| 		certDataSource.Value = certOut.Bytes() | ||||
| 	}) | ||||
| 
 | ||||
| 	By("Setting up a http client", func() { | ||||
| 		cert, err := tls.X509KeyPair(certDataSource.Value, keyDataSource.Value) | ||||
| 		Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 		certificate, err := x509.ParseCertificate(cert.Certificate[0]) | ||||
| 		Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 		certpool := x509.NewCertPool() | ||||
| 		certpool.AddCert(certificate) | ||||
| 
 | ||||
| 		transport := http.DefaultTransport.(*http.Transport).Clone() | ||||
| 		transport.TLSClientConfig.RootCAs = certpool | ||||
| 
 | ||||
| 		client = &http.Client{ | ||||
| 			Transport: transport, | ||||
| 		} | ||||
| 	}) | ||||
| }) | ||||
|  | @ -0,0 +1,245 @@ | |||
| package http | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"crypto/tls" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options/util" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||
| 	"golang.org/x/sync/errgroup" | ||||
| ) | ||||
| 
 | ||||
| // Server represents an HTTP or HTTPS server.
 | ||||
| type Server interface { | ||||
| 	// Start blocks and runs the server.
 | ||||
| 	Start(ctx context.Context) error | ||||
| } | ||||
| 
 | ||||
| // Opts contains the information required to set up the server.
 | ||||
| type Opts struct { | ||||
| 	// Handler is the http.Handler to be used to serve http pages by the server.
 | ||||
| 	Handler http.Handler | ||||
| 
 | ||||
| 	// BindAddress is the address the HTTP server should listen on.
 | ||||
| 	BindAddress string | ||||
| 
 | ||||
| 	// SecureBindAddress is the address the HTTPS server should listen on.
 | ||||
| 	SecureBindAddress string | ||||
| 
 | ||||
| 	// TLS is the TLS configuration for the server.
 | ||||
| 	TLS *options.TLS | ||||
| } | ||||
| 
 | ||||
| // NewServer creates a new Server from the options given.
 | ||||
| func NewServer(opts Opts) (Server, error) { | ||||
| 	s := &server{ | ||||
| 		handler: opts.Handler, | ||||
| 	} | ||||
| 	if err := s.setupListener(opts); err != nil { | ||||
| 		return nil, fmt.Errorf("error setting up listener: %v", err) | ||||
| 	} | ||||
| 	if err := s.setupTLSListener(opts); err != nil { | ||||
| 		return nil, fmt.Errorf("error setting up TLS listener: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	return s, nil | ||||
| } | ||||
| 
 | ||||
| // server is an implementation of the Server interface.
 | ||||
| type server struct { | ||||
| 	handler http.Handler | ||||
| 
 | ||||
| 	listener    net.Listener | ||||
| 	tlsListener net.Listener | ||||
| } | ||||
| 
 | ||||
| // setupListener sets the server listener if the HTTP server is enabled.
 | ||||
| // The HTTP server can be disabled by setting the BindAddress to "-" or by
 | ||||
| // leaving it empty.
 | ||||
| func (s *server) setupListener(opts Opts) error { | ||||
| 	if opts.BindAddress == "" || opts.BindAddress == "-" { | ||||
| 		// No HTTP listener required
 | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	networkType := getNetworkScheme(opts.BindAddress) | ||||
| 	listenAddr := getListenAddress(opts.BindAddress) | ||||
| 
 | ||||
| 	listener, err := net.Listen(networkType, listenAddr) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("listen (%s, %s) failed: %v", networkType, listenAddr, err) | ||||
| 	} | ||||
| 	s.listener = listener | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // setupTLSListener sets the server TLS listener if the HTTPS server is enabled.
 | ||||
| // The HTTPS server can be disabled by setting the SecureBindAddress to "-" or by
 | ||||
| // leaving it empty.
 | ||||
| func (s *server) setupTLSListener(opts Opts) error { | ||||
| 	if opts.SecureBindAddress == "" || opts.SecureBindAddress == "-" { | ||||
| 		// No HTTPS listener required
 | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	config := &tls.Config{ | ||||
| 		MinVersion: tls.VersionTLS12, | ||||
| 		MaxVersion: tls.VersionTLS13, | ||||
| 		NextProtos: []string{"http/1.1"}, | ||||
| 	} | ||||
| 	if opts.TLS == nil { | ||||
| 		return errors.New("no TLS config provided") | ||||
| 	} | ||||
| 	cert, err := getCertificate(opts.TLS) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("could not load certificate: %v", err) | ||||
| 	} | ||||
| 	config.Certificates = []tls.Certificate{cert} | ||||
| 
 | ||||
| 	listenAddr := getListenAddress(opts.SecureBindAddress) | ||||
| 
 | ||||
| 	listener, err := net.Listen("tcp", listenAddr) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("listen (%s) failed: %v", listenAddr, err) | ||||
| 	} | ||||
| 
 | ||||
| 	s.tlsListener = tls.NewListener(tcpKeepAliveListener{listener.(*net.TCPListener)}, config) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // Start starts the HTTP and HTTPS server if applicable.
 | ||||
| // It will block until the context is cancelled.
 | ||||
| // If any errors occur, only the first error will be returned.
 | ||||
| func (s *server) Start(ctx context.Context) error { | ||||
| 	g, groupCtx := errgroup.WithContext(ctx) | ||||
| 
 | ||||
| 	if s.listener != nil { | ||||
| 		g.Go(func() error { | ||||
| 			if err := s.startServer(groupCtx, s.listener); err != nil { | ||||
| 				return fmt.Errorf("error starting insecure server: %v", err) | ||||
| 			} | ||||
| 			return nil | ||||
| 		}) | ||||
| 	} | ||||
| 
 | ||||
| 	if s.tlsListener != nil { | ||||
| 		g.Go(func() error { | ||||
| 			if err := s.startServer(groupCtx, s.tlsListener); err != nil { | ||||
| 				return fmt.Errorf("error starting secure server: %v", err) | ||||
| 			} | ||||
| 			return nil | ||||
| 		}) | ||||
| 	} | ||||
| 
 | ||||
| 	return g.Wait() | ||||
| } | ||||
| 
 | ||||
| // startServer creates and starts a new server with the given listener.
 | ||||
| // When the given context is cancelled the server will be shutdown.
 | ||||
| // If any errors occur, only the first error will be returned.
 | ||||
| func (s *server) startServer(ctx context.Context, listener net.Listener) error { | ||||
| 	srv := &http.Server{Handler: s.handler} | ||||
| 	g, groupCtx := errgroup.WithContext(ctx) | ||||
| 
 | ||||
| 	g.Go(func() error { | ||||
| 		<-groupCtx.Done() | ||||
| 
 | ||||
| 		if err := srv.Shutdown(context.Background()); err != nil { | ||||
| 			return fmt.Errorf("error shutting down server: %v", err) | ||||
| 		} | ||||
| 		return nil | ||||
| 	}) | ||||
| 
 | ||||
| 	g.Go(func() error { | ||||
| 		if err := srv.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) { | ||||
| 			return fmt.Errorf("could not start server: %v", err) | ||||
| 		} | ||||
| 		return nil | ||||
| 	}) | ||||
| 
 | ||||
| 	return g.Wait() | ||||
| } | ||||
| 
 | ||||
| // getNetworkScheme gets the scheme for the HTTP server.
 | ||||
| func getNetworkScheme(addr string) string { | ||||
| 	var scheme string | ||||
| 	i := strings.Index(addr, "://") | ||||
| 	if i > -1 { | ||||
| 		scheme = addr[0:i] | ||||
| 	} | ||||
| 
 | ||||
| 	switch scheme { | ||||
| 	case "", "http": | ||||
| 		return "tcp" | ||||
| 	default: | ||||
| 		return scheme | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // getListenAddress gets the address for the HTTP server.
 | ||||
| func getListenAddress(addr string) string { | ||||
| 	slice := strings.SplitN(addr, "//", 2) | ||||
| 	return slice[len(slice)-1] | ||||
| } | ||||
| 
 | ||||
| // getCertificate loads the certificate data from the TLS config.
 | ||||
| func getCertificate(opts *options.TLS) (tls.Certificate, error) { | ||||
| 	keyData, err := getSecretValue(opts.Key) | ||||
| 	if err != nil { | ||||
| 		return tls.Certificate{}, fmt.Errorf("could not load key data: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	certData, err := getSecretValue(opts.Cert) | ||||
| 	if err != nil { | ||||
| 		return tls.Certificate{}, fmt.Errorf("could not load cert data: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	cert, err := tls.X509KeyPair(certData, keyData) | ||||
| 	if err != nil { | ||||
| 		return tls.Certificate{}, fmt.Errorf("could not parse certificate data: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	return cert, nil | ||||
| } | ||||
| 
 | ||||
| // getSecretValue wraps util.GetSecretValue so that we can return an error if no
 | ||||
| // source is provided.
 | ||||
| func getSecretValue(src *options.SecretSource) ([]byte, error) { | ||||
| 	if src == nil { | ||||
| 		return nil, errors.New("no configuration provided") | ||||
| 	} | ||||
| 	return util.GetSecretValue(src) | ||||
| } | ||||
| 
 | ||||
| // tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
 | ||||
| // connections. It's used by so that dead TCP connections (e.g. closing laptop
 | ||||
| // mid-download) eventually go away.
 | ||||
| type tcpKeepAliveListener struct { | ||||
| 	*net.TCPListener | ||||
| } | ||||
| 
 | ||||
| // Accept implements the TCPListener interface.
 | ||||
| // It sets the keep alive period to 3 minutes for each connection.
 | ||||
| func (ln tcpKeepAliveListener) Accept() (net.Conn, error) { | ||||
| 	tc, err := ln.AcceptTCP() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	err = tc.SetKeepAlive(true) | ||||
| 	if err != nil { | ||||
| 		logger.Errorf("Error setting Keep-Alive: %v", err) | ||||
| 	} | ||||
| 	err = tc.SetKeepAlivePeriod(3 * time.Minute) | ||||
| 	if err != nil { | ||||
| 		logger.Printf("Error setting Keep-Alive period: %v", err) | ||||
| 	} | ||||
| 	return tc, nil | ||||
| } | ||||
|  | @ -0,0 +1,35 @@ | |||
| package http | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 
 | ||||
| 	"golang.org/x/sync/errgroup" | ||||
| ) | ||||
| 
 | ||||
| // NewServerGroup creates a new Server to start and gracefully stop a collection
 | ||||
| // of Servers.
 | ||||
| func NewServerGroup(servers ...Server) Server { | ||||
| 	return &serverGroup{ | ||||
| 		servers: servers, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // serverGroup manages the starting and graceful shutdown of a collection of
 | ||||
| // servers.
 | ||||
| type serverGroup struct { | ||||
| 	servers []Server | ||||
| } | ||||
| 
 | ||||
| // Start runs the servers in the server group.
 | ||||
| func (s *serverGroup) Start(ctx context.Context) error { | ||||
| 	g, groupCtx := errgroup.WithContext(ctx) | ||||
| 
 | ||||
| 	for _, server := range s.servers { | ||||
| 		srv := server | ||||
| 		g.Go(func() error { | ||||
| 			return srv.Start(groupCtx) | ||||
| 		}) | ||||
| 	} | ||||
| 
 | ||||
| 	return g.Wait() | ||||
| } | ||||
|  | @ -0,0 +1,102 @@ | |||
| package http | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 
 | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/gomega" | ||||
| ) | ||||
| 
 | ||||
| var _ = Describe("Server Group", func() { | ||||
| 	var m1, m2, m3 *mockServer | ||||
| 	var ctx context.Context | ||||
| 	var cancel context.CancelFunc | ||||
| 	var group Server | ||||
| 
 | ||||
| 	BeforeEach(func() { | ||||
| 		ctx, cancel = context.WithCancel(context.Background()) | ||||
| 
 | ||||
| 		m1 = newMockServer() | ||||
| 		m2 = newMockServer() | ||||
| 		m3 = newMockServer() | ||||
| 		group = NewServerGroup(m1, m2, m3) | ||||
| 	}) | ||||
| 
 | ||||
| 	AfterEach(func() { | ||||
| 		cancel() | ||||
| 	}) | ||||
| 
 | ||||
| 	It("starts each server in the group", func() { | ||||
| 		go func() { | ||||
| 			defer GinkgoRecover() | ||||
| 			Expect(group.Start(ctx)).To(Succeed()) | ||||
| 		}() | ||||
| 
 | ||||
| 		Eventually(m1.started).Should(BeClosed(), "mock server 1 not started") | ||||
| 		Eventually(m2.started).Should(BeClosed(), "mock server 2 not started") | ||||
| 		Eventually(m3.started).Should(BeClosed(), "mock server 3 not started") | ||||
| 	}) | ||||
| 
 | ||||
| 	It("stop each server in the group when the context is cancelled", func() { | ||||
| 		go func() { | ||||
| 			defer GinkgoRecover() | ||||
| 			Expect(group.Start(ctx)).To(Succeed()) | ||||
| 		}() | ||||
| 
 | ||||
| 		cancel() | ||||
| 		Eventually(m1.stopped).Should(BeClosed(), "mock server 1 not stopped") | ||||
| 		Eventually(m2.stopped).Should(BeClosed(), "mock server 2 not stopped") | ||||
| 		Eventually(m3.stopped).Should(BeClosed(), "mock server 3 not stopped") | ||||
| 	}) | ||||
| 
 | ||||
| 	It("stop each server in the group when the an error occurs", func() { | ||||
| 		err := errors.New("server error") | ||||
| 		go func() { | ||||
| 			defer GinkgoRecover() | ||||
| 			Expect(group.Start(ctx)).To(MatchError(err)) | ||||
| 		}() | ||||
| 
 | ||||
| 		m2.errors <- err | ||||
| 		Eventually(m1.stopped).Should(BeClosed(), "mock server 1 not stopped") | ||||
| 		Eventually(m2.stopped).Should(BeClosed(), "mock server 2 not stopped") | ||||
| 		Eventually(m3.stopped).Should(BeClosed(), "mock server 3 not stopped") | ||||
| 	}) | ||||
| }) | ||||
| 
 | ||||
| // mockServer is used to test the server group can start
 | ||||
| // and stop multiple servers simultaneously.
 | ||||
| type mockServer struct { | ||||
| 	started     chan struct{} | ||||
| 	startClosed bool | ||||
| 	stopped     chan struct{} | ||||
| 	stopClosed  bool | ||||
| 	errors      chan error | ||||
| } | ||||
| 
 | ||||
| func newMockServer() *mockServer { | ||||
| 	return &mockServer{ | ||||
| 		started: make(chan struct{}), | ||||
| 		stopped: make(chan struct{}), | ||||
| 		errors:  make(chan error), | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (m *mockServer) Start(ctx context.Context) error { | ||||
| 	if !m.startClosed { | ||||
| 		close(m.started) | ||||
| 		m.startClosed = true | ||||
| 	} | ||||
| 	defer func() { | ||||
| 		if !m.stopClosed { | ||||
| 			close(m.stopped) | ||||
| 			m.stopClosed = true | ||||
| 		} | ||||
| 	}() | ||||
| 	select { | ||||
| 	case <-ctx.Done(): | ||||
| 		return nil | ||||
| 	case err := <-m.errors: | ||||
| 		return err | ||||
| 	} | ||||
| } | ||||
|  | @ -0,0 +1,472 @@ | |||
| package http | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io/ioutil" | ||||
| 	"net/http" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/ginkgo/extensions/table" | ||||
| 	. "github.com/onsi/gomega" | ||||
| ) | ||||
| 
 | ||||
| const hello = "Hello World!" | ||||
| 
 | ||||
| var _ = Describe("Server", func() { | ||||
| 	handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||
| 		rw.Write([]byte(hello)) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("NewServer", func() { | ||||
| 		type newServerTableInput struct { | ||||
| 			opts               Opts | ||||
| 			expectedErr        error | ||||
| 			expectHTTPListener bool | ||||
| 			expectTLSListener  bool | ||||
| 		} | ||||
| 
 | ||||
| 		DescribeTable("When creating the new server from the options", func(in *newServerTableInput) { | ||||
| 			srv, err := NewServer(in.opts) | ||||
| 			if in.expectedErr != nil { | ||||
| 				Expect(err).To(MatchError(ContainSubstring(in.expectedErr.Error()))) | ||||
| 				Expect(srv).To(BeNil()) | ||||
| 				return | ||||
| 			} | ||||
| 
 | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 			s, ok := srv.(*server) | ||||
| 			Expect(ok).To(BeTrue()) | ||||
| 
 | ||||
| 			Expect(s.listener != nil).To(Equal(in.expectHTTPListener)) | ||||
| 			if in.expectHTTPListener { | ||||
| 				Expect(s.listener.Close()).To(Succeed()) | ||||
| 			} | ||||
| 			Expect(s.tlsListener != nil).To(Equal(in.expectTLSListener)) | ||||
| 			if in.expectTLSListener { | ||||
| 				Expect(s.tlsListener.Close()).To(Succeed()) | ||||
| 			} | ||||
| 		}, | ||||
| 			Entry("with a valid http bind address", &newServerTableInput{ | ||||
| 				opts: Opts{ | ||||
| 					Handler:     handler, | ||||
| 					BindAddress: "127.0.0.1:0", | ||||
| 				}, | ||||
| 				expectedErr:        nil, | ||||
| 				expectHTTPListener: true, | ||||
| 				expectTLSListener:  false, | ||||
| 			}), | ||||
| 			Entry("with a valid https bind address, with no TLS config", &newServerTableInput{ | ||||
| 				opts: Opts{ | ||||
| 					Handler:           handler, | ||||
| 					SecureBindAddress: "127.0.0.1:0", | ||||
| 				}, | ||||
| 				expectedErr:        errors.New("error setting up TLS listener: no TLS config provided"), | ||||
| 				expectHTTPListener: false, | ||||
| 				expectTLSListener:  false, | ||||
| 			}), | ||||
| 			Entry("with a valid https bind address, and valid TLS config", &newServerTableInput{ | ||||
| 				opts: Opts{ | ||||
| 					Handler:           handler, | ||||
| 					SecureBindAddress: "127.0.0.1:0", | ||||
| 					TLS: &options.TLS{ | ||||
| 						Key:  &keyDataSource, | ||||
| 						Cert: &certDataSource, | ||||
| 					}, | ||||
| 				}, | ||||
| 				expectedErr:        nil, | ||||
| 				expectHTTPListener: false, | ||||
| 				expectTLSListener:  true, | ||||
| 			}), | ||||
| 			Entry("with a both a valid http and valid https bind address, and valid TLS config", &newServerTableInput{ | ||||
| 				opts: Opts{ | ||||
| 					Handler:           handler, | ||||
| 					BindAddress:       "127.0.0.1:0", | ||||
| 					SecureBindAddress: "127.0.0.1:0", | ||||
| 					TLS: &options.TLS{ | ||||
| 						Key:  &keyDataSource, | ||||
| 						Cert: &certDataSource, | ||||
| 					}, | ||||
| 				}, | ||||
| 				expectedErr:        nil, | ||||
| 				expectHTTPListener: true, | ||||
| 				expectTLSListener:  true, | ||||
| 			}), | ||||
| 			Entry("with a \"-\" for the bind address", &newServerTableInput{ | ||||
| 				opts: Opts{ | ||||
| 					Handler:     handler, | ||||
| 					BindAddress: "-", | ||||
| 				}, | ||||
| 				expectedErr:        nil, | ||||
| 				expectHTTPListener: false, | ||||
| 				expectTLSListener:  false, | ||||
| 			}), | ||||
| 			Entry("with a \"-\" for the secure bind address", &newServerTableInput{ | ||||
| 				opts: Opts{ | ||||
| 					Handler:           handler, | ||||
| 					SecureBindAddress: "-", | ||||
| 				}, | ||||
| 				expectedErr:        nil, | ||||
| 				expectHTTPListener: false, | ||||
| 				expectTLSListener:  false, | ||||
| 			}), | ||||
| 			Entry("with an invalid bind address scheme", &newServerTableInput{ | ||||
| 				opts: Opts{ | ||||
| 					Handler:     handler, | ||||
| 					BindAddress: "invalid://127.0.0.1:0", | ||||
| 				}, | ||||
| 				expectedErr:        errors.New("error setting up listener: listen (invalid, 127.0.0.1:0) failed: listen invalid: unknown network invalid"), | ||||
| 				expectHTTPListener: false, | ||||
| 				expectTLSListener:  false, | ||||
| 			}), | ||||
| 			Entry("with an invalid secure bind address scheme", &newServerTableInput{ | ||||
| 				opts: Opts{ | ||||
| 					Handler:           handler, | ||||
| 					SecureBindAddress: "invalid://127.0.0.1:0", | ||||
| 					TLS: &options.TLS{ | ||||
| 						Key:  &keyDataSource, | ||||
| 						Cert: &certDataSource, | ||||
| 					}, | ||||
| 				}, | ||||
| 				expectedErr:        nil, | ||||
| 				expectHTTPListener: false, | ||||
| 				expectTLSListener:  true, | ||||
| 			}), | ||||
| 			Entry("with an invalid bind address port", &newServerTableInput{ | ||||
| 				opts: Opts{ | ||||
| 					Handler:     handler, | ||||
| 					BindAddress: "127.0.0.1:a", | ||||
| 				}, | ||||
| 				expectedErr:        errors.New("error setting up listener: listen (tcp, 127.0.0.1:a) failed: listen tcp: lookup tcp/a: "), | ||||
| 				expectHTTPListener: false, | ||||
| 				expectTLSListener:  false, | ||||
| 			}), | ||||
| 			Entry("with an invalid secure bind address port", &newServerTableInput{ | ||||
| 				opts: Opts{ | ||||
| 					Handler:           handler, | ||||
| 					SecureBindAddress: "127.0.0.1:a", | ||||
| 					TLS: &options.TLS{ | ||||
| 						Key:  &keyDataSource, | ||||
| 						Cert: &certDataSource, | ||||
| 					}, | ||||
| 				}, | ||||
| 				expectedErr:        errors.New("error setting up TLS listener: listen (127.0.0.1:a) failed: listen tcp: lookup tcp/a: "), | ||||
| 				expectHTTPListener: false, | ||||
| 				expectTLSListener:  false, | ||||
| 			}), | ||||
| 			Entry("with an invalid TLS key", &newServerTableInput{ | ||||
| 				opts: Opts{ | ||||
| 					Handler:           handler, | ||||
| 					SecureBindAddress: "127.0.0.1:0", | ||||
| 					TLS: &options.TLS{ | ||||
| 						Key: &options.SecretSource{ | ||||
| 							Value: []byte("invalid"), | ||||
| 						}, | ||||
| 						Cert: &certDataSource, | ||||
| 					}, | ||||
| 				}, | ||||
| 				expectedErr:        errors.New("error setting up TLS listener: could not load certificate: could not parse certificate data: tls: failed to find any PEM data in key input"), | ||||
| 				expectHTTPListener: false, | ||||
| 				expectTLSListener:  false, | ||||
| 			}), | ||||
| 			Entry("with an invalid TLS cert", &newServerTableInput{ | ||||
| 				opts: Opts{ | ||||
| 					Handler:           handler, | ||||
| 					SecureBindAddress: "127.0.0.1:0", | ||||
| 					TLS: &options.TLS{ | ||||
| 						Key: &keyDataSource, | ||||
| 						Cert: &options.SecretSource{ | ||||
| 							Value: []byte("invalid"), | ||||
| 						}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				expectedErr:        errors.New("error setting up TLS listener: could not load certificate: could not parse certificate data: tls: failed to find any PEM data in certificate input"), | ||||
| 				expectHTTPListener: false, | ||||
| 				expectTLSListener:  false, | ||||
| 			}), | ||||
| 			Entry("with no TLS key", &newServerTableInput{ | ||||
| 				opts: Opts{ | ||||
| 					Handler:           handler, | ||||
| 					SecureBindAddress: "127.0.0.1:0", | ||||
| 					TLS: &options.TLS{ | ||||
| 						Cert: &certDataSource, | ||||
| 					}, | ||||
| 				}, | ||||
| 				expectedErr:        errors.New("error setting up TLS listener: could not load certificate: could not load key data: no configuration provided"), | ||||
| 				expectHTTPListener: false, | ||||
| 				expectTLSListener:  false, | ||||
| 			}), | ||||
| 			Entry("with no TLS cert", &newServerTableInput{ | ||||
| 				opts: Opts{ | ||||
| 					Handler:           handler, | ||||
| 					SecureBindAddress: "127.0.0.1:0", | ||||
| 					TLS: &options.TLS{ | ||||
| 						Key: &keyDataSource, | ||||
| 					}, | ||||
| 				}, | ||||
| 				expectedErr:        errors.New("error setting up TLS listener: could not load certificate: could not load cert data: no configuration provided"), | ||||
| 				expectHTTPListener: false, | ||||
| 				expectTLSListener:  false, | ||||
| 			}), | ||||
| 			Entry("when the bind address is prefixed with the http scheme", &newServerTableInput{ | ||||
| 				opts: Opts{ | ||||
| 					Handler:     handler, | ||||
| 					BindAddress: "http://127.0.0.1:0", | ||||
| 				}, | ||||
| 				expectedErr:        nil, | ||||
| 				expectHTTPListener: true, | ||||
| 				expectTLSListener:  false, | ||||
| 			}), | ||||
| 			Entry("when the secure bind address is prefixed with the https scheme", &newServerTableInput{ | ||||
| 				opts: Opts{ | ||||
| 					Handler:           handler, | ||||
| 					SecureBindAddress: "https://127.0.0.1:0", | ||||
| 					TLS: &options.TLS{ | ||||
| 						Key:  &keyDataSource, | ||||
| 						Cert: &certDataSource, | ||||
| 					}, | ||||
| 				}, | ||||
| 				expectedErr:        nil, | ||||
| 				expectHTTPListener: false, | ||||
| 				expectTLSListener:  true, | ||||
| 			}), | ||||
| 		) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("Start", func() { | ||||
| 		var srv Server | ||||
| 		var ctx context.Context | ||||
| 		var cancel context.CancelFunc | ||||
| 
 | ||||
| 		BeforeEach(func() { | ||||
| 			ctx, cancel = context.WithCancel(context.Background()) | ||||
| 		}) | ||||
| 
 | ||||
| 		AfterEach(func() { | ||||
| 			cancel() | ||||
| 		}) | ||||
| 
 | ||||
| 		Context("with an http server", func() { | ||||
| 			var listenAddr string | ||||
| 
 | ||||
| 			BeforeEach(func() { | ||||
| 				var err error | ||||
| 				srv, err = NewServer(Opts{ | ||||
| 					Handler:     handler, | ||||
| 					BindAddress: "127.0.0.1:0", | ||||
| 				}) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 				s, ok := srv.(*server) | ||||
| 				Expect(ok).To(BeTrue()) | ||||
| 
 | ||||
| 				listenAddr = fmt.Sprintf("http://%s/", s.listener.Addr().String()) | ||||
| 			}) | ||||
| 
 | ||||
| 			It("Starts the server and serves the handler", func() { | ||||
| 				go func() { | ||||
| 					defer GinkgoRecover() | ||||
| 					Expect(srv.Start(ctx)).To(Succeed()) | ||||
| 				}() | ||||
| 
 | ||||
| 				resp, err := client.Get(listenAddr) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(resp.StatusCode).To(Equal(http.StatusOK)) | ||||
| 
 | ||||
| 				body, err := ioutil.ReadAll(resp.Body) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(string(body)).To(Equal(hello)) | ||||
| 			}) | ||||
| 
 | ||||
| 			It("Stops the server when the context is cancelled", func() { | ||||
| 				go func() { | ||||
| 					defer GinkgoRecover() | ||||
| 					Expect(srv.Start(ctx)).To(Succeed()) | ||||
| 				}() | ||||
| 
 | ||||
| 				_, err := client.Get(listenAddr) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 				cancel() | ||||
| 
 | ||||
| 				Eventually(func() error { | ||||
| 					_, err := client.Get(listenAddr) | ||||
| 					return err | ||||
| 				}).Should(HaveOccurred()) | ||||
| 			}) | ||||
| 		}) | ||||
| 
 | ||||
| 		Context("with an https server", func() { | ||||
| 			var secureListenAddr string | ||||
| 
 | ||||
| 			BeforeEach(func() { | ||||
| 				var err error | ||||
| 				srv, err = NewServer(Opts{ | ||||
| 					Handler:           handler, | ||||
| 					SecureBindAddress: "127.0.0.1:0", | ||||
| 					TLS: &options.TLS{ | ||||
| 						Key:  &keyDataSource, | ||||
| 						Cert: &certDataSource, | ||||
| 					}, | ||||
| 				}) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 				s, ok := srv.(*server) | ||||
| 				Expect(ok).To(BeTrue()) | ||||
| 
 | ||||
| 				secureListenAddr = fmt.Sprintf("https://%s/", s.tlsListener.Addr().String()) | ||||
| 			}) | ||||
| 
 | ||||
| 			It("Starts the server and serves the handler", func() { | ||||
| 				go func() { | ||||
| 					defer GinkgoRecover() | ||||
| 					Expect(srv.Start(ctx)).To(Succeed()) | ||||
| 				}() | ||||
| 
 | ||||
| 				resp, err := client.Get(secureListenAddr) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(resp.StatusCode).To(Equal(http.StatusOK)) | ||||
| 
 | ||||
| 				body, err := ioutil.ReadAll(resp.Body) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(string(body)).To(Equal(hello)) | ||||
| 			}) | ||||
| 
 | ||||
| 			It("Stops the server when the context is cancelled", func() { | ||||
| 				go func() { | ||||
| 					defer GinkgoRecover() | ||||
| 					Expect(srv.Start(ctx)).To(Succeed()) | ||||
| 				}() | ||||
| 
 | ||||
| 				_, err := client.Get(secureListenAddr) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 				cancel() | ||||
| 
 | ||||
| 				Eventually(func() error { | ||||
| 					_, err := client.Get(secureListenAddr) | ||||
| 					return err | ||||
| 				}).Should(HaveOccurred()) | ||||
| 			}) | ||||
| 
 | ||||
| 			It("Serves the certificate provided", func() { | ||||
| 				go func() { | ||||
| 					defer GinkgoRecover() | ||||
| 					Expect(srv.Start(ctx)).To(Succeed()) | ||||
| 				}() | ||||
| 
 | ||||
| 				resp, err := client.Get(secureListenAddr) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(resp.StatusCode).To(Equal(http.StatusOK)) | ||||
| 
 | ||||
| 				Expect(resp.TLS.VerifiedChains).Should(HaveLen(1)) | ||||
| 				Expect(resp.TLS.VerifiedChains[0]).Should(HaveLen(1)) | ||||
| 				Expect(resp.TLS.VerifiedChains[0][0].Raw).Should(Equal(certData)) | ||||
| 			}) | ||||
| 		}) | ||||
| 
 | ||||
| 		Context("with both an http and an https server", func() { | ||||
| 			var listenAddr, secureListenAddr string | ||||
| 
 | ||||
| 			BeforeEach(func() { | ||||
| 				var err error | ||||
| 				srv, err = NewServer(Opts{ | ||||
| 					Handler:           handler, | ||||
| 					BindAddress:       "127.0.0.1:0", | ||||
| 					SecureBindAddress: "127.0.0.1:0", | ||||
| 					TLS: &options.TLS{ | ||||
| 						Key:  &keyDataSource, | ||||
| 						Cert: &certDataSource, | ||||
| 					}, | ||||
| 				}) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 				s, ok := srv.(*server) | ||||
| 				Expect(ok).To(BeTrue()) | ||||
| 
 | ||||
| 				listenAddr = fmt.Sprintf("http://%s/", s.listener.Addr().String()) | ||||
| 				secureListenAddr = fmt.Sprintf("https://%s/", s.tlsListener.Addr().String()) | ||||
| 			}) | ||||
| 
 | ||||
| 			It("Starts the server and serves the handler on http", func() { | ||||
| 				go func() { | ||||
| 					defer GinkgoRecover() | ||||
| 					Expect(srv.Start(ctx)).To(Succeed()) | ||||
| 				}() | ||||
| 
 | ||||
| 				resp, err := client.Get(listenAddr) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(resp.StatusCode).To(Equal(http.StatusOK)) | ||||
| 
 | ||||
| 				body, err := ioutil.ReadAll(resp.Body) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(string(body)).To(Equal(hello)) | ||||
| 			}) | ||||
| 
 | ||||
| 			It("Starts the server and serves the handler on https", func() { | ||||
| 				go func() { | ||||
| 					defer GinkgoRecover() | ||||
| 					Expect(srv.Start(ctx)).To(Succeed()) | ||||
| 				}() | ||||
| 
 | ||||
| 				resp, err := client.Get(secureListenAddr) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(resp.StatusCode).To(Equal(http.StatusOK)) | ||||
| 
 | ||||
| 				body, err := ioutil.ReadAll(resp.Body) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(string(body)).To(Equal(hello)) | ||||
| 			}) | ||||
| 
 | ||||
| 			It("Stops both servers when the context is cancelled", func() { | ||||
| 				go func() { | ||||
| 					defer GinkgoRecover() | ||||
| 					Expect(srv.Start(ctx)).To(Succeed()) | ||||
| 				}() | ||||
| 
 | ||||
| 				_, err := client.Get(listenAddr) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				_, err = client.Get(secureListenAddr) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 				cancel() | ||||
| 
 | ||||
| 				Eventually(func() error { | ||||
| 					_, err := client.Get(listenAddr) | ||||
| 					return err | ||||
| 				}).Should(HaveOccurred()) | ||||
| 				Eventually(func() error { | ||||
| 					_, err := client.Get(secureListenAddr) | ||||
| 					return err | ||||
| 				}).Should(HaveOccurred()) | ||||
| 			}) | ||||
| 		}) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("getNetworkScheme", func() { | ||||
| 		DescribeTable("should return the scheme", func(in, expected string) { | ||||
| 			Expect(getNetworkScheme(in)).To(Equal(expected)) | ||||
| 		}, | ||||
| 			Entry("with no scheme", "127.0.0.1:0", "tcp"), | ||||
| 			Entry("with a tcp scheme", "tcp://127.0.0.1:0", "tcp"), | ||||
| 			Entry("with a http scheme", "http://192.168.0.1:1", "tcp"), | ||||
| 			Entry("with a unix scheme", "unix://172.168.16.2:2", "unix"), | ||||
| 			Entry("with a random scheme", "random://10.10.10.10:10", "random"), | ||||
| 		) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("getListenAddress", func() { | ||||
| 		DescribeTable("should remove the scheme", func(in, expected string) { | ||||
| 			Expect(getListenAddress(in)).To(Equal(expected)) | ||||
| 		}, | ||||
| 			Entry("with no scheme", "127.0.0.1:0", "127.0.0.1:0"), | ||||
| 			Entry("with a tcp scheme", "tcp://127.0.0.1:0", "127.0.0.1:0"), | ||||
| 			Entry("with a http scheme", "http://192.168.0.1:1", "192.168.0.1:1"), | ||||
| 			Entry("with a unix scheme", "unix://172.168.16.2:2", "172.168.16.2:2"), | ||||
| 			Entry("with a random scheme", "random://10.10.10.10:10", "10.10.10.10:10"), | ||||
| 		) | ||||
| 	}) | ||||
| }) | ||||
		Loading…
	
		Reference in New Issue