Implements --real-client-ip-header option. (#503)
* Implements -real-client-ip-header option. * The -real-client-ip-header determines what HTTP header is used for determining the "real client IP" of the remote client. * The -real-client-ip-header option supports the following headers: X-Forwarded-For X-ProxyUser-IP and X-Real-IP (default). * Introduces new realClientIPParser interface to allow for multiple polymorphic classes to decide how to determine the real client IP. * TODO: implement the more standard, but more complex `Forwarded` HTTP header. * Corrected order of expected/actual in test cases * Improved error message in getRemoteIP * Add tests for getRemoteIP and getClientString * Add comment explaining splitting of header * Update documentation on -real-client-ip-header w/o -reverse-proxy * Add PR number in changelog. * Fix typo repeated word: "it" Co-Authored-By: Joel Speed <Joel.speed@hotmail.co.uk> * Update extended configuration language * Simplify the language around dependance on -reverse-proxy Co-Authored-By: Joel Speed <Joel.speed@hotmail.co.uk> * Added completions * Reorder real client IP header options * Update CHANGELOG.md * Apply suggestions from code review Co-authored-by: Isabelle COWAN-BERGMAN <Izzette@users.noreply.github.com> Co-authored-by: Joel Speed <Joel.speed@hotmail.co.uk> Co-authored-by: Henry Jenkins <henry@henryjenkins.name>
This commit is contained in:
		
							parent
							
								
									d0cfca4b73
								
							
						
					
					
						commit
						111d17efde
					
				|  | @ -35,6 +35,7 @@ | |||
| 
 | ||||
| ## Changes since v5.1.1 | ||||
| 
 | ||||
| - [#503](https://github.com/oauth2-proxy/oauth2-proxy/pull/503) Implements --real-client-ip-header option to select the header from which to obtain a proxied client's IP (@Izzette) | ||||
| - [#529](https://github.com/oauth2-proxy/oauth2-proxy/pull/529) Add local test environments for testing changes and new features (@JoelSpeed) | ||||
| - [#537](https://github.com/oauth2-proxy/oauth2-proxy/pull/537) Drop Fallback to Email if User not set (@JoelSpeed) | ||||
| - [#535](https://github.com/oauth2-proxy/oauth2-proxy/pull/535) Drop support for pre v3.1 cookies (@JoelSpeed) | ||||
|  |  | |||
|  | @ -20,6 +20,10 @@ _oauth2_proxy() { | |||
| 			COMPREPLY=( $(compgen -W "google azure facebook github keycloak gitlab linkedin login.gov digitalocean" -- ${cur}) ) | ||||
| 			return 0 | ||||
| 			;; | ||||
| 		--real-client-ip-header) | ||||
| 			COMPREPLY=( $(compgen -W 'X-Real-IP X-Forwarded-For X-ProxyUser-IP' -- ${cur}) ) | ||||
| 			return 0 | ||||
| 			;; | ||||
| 		-@(http-address|https-address|redirect-url|upstream|basic-auth-password|skip-auth-regex|flush-interval|extra-jwt-issuers|email-domain|whitelist-domain|keycloak-group|azure-tenant|bitbucket-team|bitbucket-repository|github-org|github-team|github-repo|github-token|gitlab-group|google-group|google-admin-email|google-service-account-json|client-id|client_secret|banner|footer|proxy-prefix|ping-path|cookie-name|cookie-secret|cookie-domain|cookie-path|cookie-expire|cookie-refresh|cookie-samesite|redist-sentinel-master-name|redist-sentinel-connection-urls|redist-cluster-connection-urls|logging-max-size|logging-max-age|logging-max-backups|standard-logging-format|request-logging-format|exclude-logging-paths|auth-logging-format|oidc-issuer-url|oidc-jwks-url|login-url|redeem-url|profile-url|resource|validate-url|scope|approval-prompt|signature-key|acr-values|jwt-key|pubjwk-url)) | ||||
| 			return 0 | ||||
| 			;; | ||||
|  |  | |||
|  | @ -90,6 +90,7 @@ An example [oauth2-proxy.cfg]({{ site.gitweb }}/contrib/oauth2-proxy.cfg.example | |||
| | `--proxy-prefix` | string | the url root path that this proxy should be nested under (e.g. /`<oauth2>/sign_in`) | `"/oauth2"` | | ||||
| | `--proxy-websockets` | bool | enables WebSocket proxying | true | | ||||
| | `--pubjwk-url` | string | JWK pubkey access endpoint: required by login.gov | | | ||||
| | `--real-client-ip-header` | string | Header used to determine the real IP of the client, requires `--reverse-proxy` to be set (one of: X-Forwarded-For, X-Real-IP, or X-ProxyUser-IP) | X-Real-IP | | ||||
| | `--redeem-url` | string | Token redemption endpoint | | | ||||
| | `--redirect-url` | string | the OAuth Redirect URL. ie: `"https://internalapp.yourcompany.com/oauth2/callback"` | | | ||||
| | `--redis-cluster-connection-urls` | string \| list | List of Redis cluster connection URLs (eg redis://HOST[:PORT]). Used in conjunction with `--redis-use-cluster` | | | ||||
|  |  | |||
							
								
								
									
										1
									
								
								main.go
								
								
								
								
							
							
						
						
									
										1
									
								
								main.go
								
								
								
								
							|  | @ -26,6 +26,7 @@ func main() { | |||
| 	flagSet.String("http-address", "127.0.0.1:4180", "[http://]<addr>:<port> or unix://<path> to listen on for HTTP clients") | ||||
| 	flagSet.String("https-address", ":443", "<addr>:<port> to listen on for HTTPS clients") | ||||
| 	flagSet.Bool("reverse-proxy", false, "are we running behind a reverse proxy, controls whether headers like X-Real-Ip are accepted") | ||||
| 	flagSet.String("real-client-ip-header", "X-Real-IP", "Header used to determine the real IP of the client (one of: X-Forwarded-For, X-Real-IP, or X-ProxyUser-IP)") | ||||
| 	flagSet.Bool("force-https", false, "force HTTPS redirect for HTTP requests") | ||||
| 	flagSet.String("tls-cert-file", "", "path to certificate file") | ||||
| 	flagSet.String("tls-key-file", "", "path to private key file") | ||||
|  |  | |||
|  | @ -112,6 +112,7 @@ type OAuthProxy struct { | |||
| 	jwtBearerVerifiers   []*oidc.IDTokenVerifier | ||||
| 	compiledRegex        []*regexp.Regexp | ||||
| 	templates            *template.Template | ||||
| 	realClientIPParser   realClientIPParser | ||||
| 	Banner               string | ||||
| 	Footer               string | ||||
| } | ||||
|  | @ -308,6 +309,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { | |||
| 		skipJwtBearerTokens:  opts.SkipJwtBearerTokens, | ||||
| 		jwtBearerVerifiers:   opts.jwtBearerVerifiers, | ||||
| 		compiledRegex:        opts.compiledRegex, | ||||
| 		realClientIPParser:   opts.realClientIPParser, | ||||
| 		SetXAuthRequest:      opts.SetXAuthRequest, | ||||
| 		PassBasicAuth:        opts.PassBasicAuth, | ||||
| 		SetBasicAuth:         opts.SetBasicAuth, | ||||
|  | @ -636,14 +638,6 @@ func (p *OAuthProxy) IsWhitelistedPath(path string) bool { | |||
| 	return false | ||||
| } | ||||
| 
 | ||||
| func getRemoteAddr(req *http.Request) (s string) { | ||||
| 	s = req.RemoteAddr | ||||
| 	if req.Header.Get("X-Real-IP") != "" { | ||||
| 		s += fmt.Sprintf(" (%q)", req.Header.Get("X-Real-IP")) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en
 | ||||
| var noCacheHeaders = map[string]string{ | ||||
| 	"Expires":         time.Unix(0, 0).Format(time.RFC1123), | ||||
|  | @ -766,7 +760,7 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) { | |||
| // OAuthCallback is the OAuth2 authentication flow callback that finishes the
 | ||||
| // OAuth2 authentication flow
 | ||||
| func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { | ||||
| 	remoteAddr := getRemoteAddr(req) | ||||
| 	remoteAddr := getClientString(p.realClientIPParser, req, true) | ||||
| 
 | ||||
| 	// finish the oauth cycle
 | ||||
| 	err := req.ParseForm() | ||||
|  | @ -894,7 +888,7 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R | |||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	remoteAddr := getRemoteAddr(req) | ||||
| 	remoteAddr := getClientString(p.realClientIPParser, req, true) | ||||
| 	if session == nil { | ||||
| 		session, err = p.LoadCookiedSession(req) | ||||
| 		if err != nil { | ||||
|  |  | |||
							
								
								
									
										13
									
								
								options.go
								
								
								
								
							
							
						
						
									
										13
									
								
								options.go
								
								
								
								
							|  | @ -37,6 +37,7 @@ type Options struct { | |||
| 	HTTPAddress        string `flag:"http-address" cfg:"http_address" env:"OAUTH2_PROXY_HTTP_ADDRESS"` | ||||
| 	HTTPSAddress       string `flag:"https-address" cfg:"https_address" env:"OAUTH2_PROXY_HTTPS_ADDRESS"` | ||||
| 	ReverseProxy       bool   `flag:"reverse-proxy" cfg:"reverse_proxy" env:"OAUTH2_PROXY_REVERSE_PROXY"` | ||||
| 	RealClientIPHeader string `flag:"real-client-ip-header" cfg:"real_client_ip_header" env:"OAUTH2_PROXY_REAL_CLIENT_IP_HEADER"` | ||||
| 	ForceHTTPS         bool   `flag:"force-https" cfg:"force_https" env:"OAUTH2_PROXY_FORCE_HTTPS"` | ||||
| 	RedirectURL        string `flag:"redirect-url" cfg:"redirect_url" env:"OAUTH2_PROXY_REDIRECT_URL"` | ||||
| 	ClientID           string `flag:"client-id" cfg:"client_id" env:"OAUTH2_PROXY_CLIENT_ID"` | ||||
|  | @ -139,6 +140,7 @@ type Options struct { | |||
| 	signatureData      *SignatureData | ||||
| 	oidcVerifier       *oidc.IDTokenVerifier | ||||
| 	jwtBearerVerifiers []*oidc.IDTokenVerifier | ||||
| 	realClientIPParser realClientIPParser | ||||
| } | ||||
| 
 | ||||
| // SignatureData holds hmacauth signature hash and key
 | ||||
|  | @ -456,6 +458,13 @@ func (o *Options) Validate() error { | |||
| 	msgs = validateCookieName(o, msgs) | ||||
| 	msgs = setupLogger(o, msgs) | ||||
| 
 | ||||
| 	if o.ReverseProxy { | ||||
| 		o.realClientIPParser, err = getRealClientIPParser(o.RealClientIPHeader) | ||||
| 		if err != nil { | ||||
| 			msgs = append(msgs, fmt.Sprintf("real_client_ip_header (%s) not accepted parameter value: %v", o.RealClientIPHeader, err)) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if len(msgs) != 0 { | ||||
| 		return fmt.Errorf("invalid configuration:\n  %s", | ||||
| 			strings.Join(msgs, "\n  ")) | ||||
|  | @ -695,7 +704,9 @@ func setupLogger(o *Options, msgs []string) []string { | |||
| 	logger.SetStandardTemplate(o.StandardLoggingFormat) | ||||
| 	logger.SetAuthTemplate(o.AuthLoggingFormat) | ||||
| 	logger.SetReqTemplate(o.RequestLoggingFormat) | ||||
| 	logger.SetReverseProxy(o.ReverseProxy) | ||||
| 	logger.SetGetClientFunc(func(r *http.Request) string { | ||||
| 		return getClientString(o.realClientIPParser, r, false) | ||||
| 	}) | ||||
| 
 | ||||
| 	excludePaths := make([]string, 0) | ||||
| 	excludePaths = append(excludePaths, strings.Split(o.ExcludeLoggingPaths, ",")...) | ||||
|  |  | |||
|  | @ -326,3 +326,46 @@ func TestGCPHealthcheck(t *testing.T) { | |||
| 	o.GCPHealthChecks = true | ||||
| 	assert.Equal(t, nil, o.Validate()) | ||||
| } | ||||
| 
 | ||||
| func TestRealClientIPHeader(t *testing.T) { | ||||
| 	var o *Options | ||||
| 	var err error | ||||
| 	var expected string | ||||
| 
 | ||||
| 	// Ensure nil if ReverseProxy not set.
 | ||||
| 	o = testOptions() | ||||
| 	o.RealClientIPHeader = "X-Real-IP" | ||||
| 	assert.Equal(t, nil, o.Validate()) | ||||
| 	assert.Nil(t, o.realClientIPParser) | ||||
| 
 | ||||
| 	// Ensure simple use case works.
 | ||||
| 	o = testOptions() | ||||
| 	o.ReverseProxy = true | ||||
| 	o.RealClientIPHeader = "X-Forwarded-For" | ||||
| 	assert.Equal(t, nil, o.Validate()) | ||||
| 	assert.NotNil(t, o.realClientIPParser) | ||||
| 
 | ||||
| 	// Ensure unknown header format process an error.
 | ||||
| 	o = testOptions() | ||||
| 	o.ReverseProxy = true | ||||
| 	o.RealClientIPHeader = "Forwarded" | ||||
| 	err = o.Validate() | ||||
| 	assert.NotEqual(t, nil, err) | ||||
| 	expected = errorMsg([]string{ | ||||
| 		"real_client_ip_header (Forwarded) not accepted parameter value: the http header key (Forwarded) is either invalid or unsupported", | ||||
| 	}) | ||||
| 	assert.Equal(t, expected, err.Error()) | ||||
| 	assert.Nil(t, o.realClientIPParser) | ||||
| 
 | ||||
| 	// Ensure invalid header format produces an error.
 | ||||
| 	o = testOptions() | ||||
| 	o.ReverseProxy = true | ||||
| 	o.RealClientIPHeader = "!934invalidheader-23:" | ||||
| 	err = o.Validate() | ||||
| 	assert.NotEqual(t, nil, err) | ||||
| 	expected = errorMsg([]string{ | ||||
| 		"real_client_ip_header (!934invalidheader-23:) not accepted parameter value: the http header key (!934invalidheader-23:) is either invalid or unsupported", | ||||
| 	}) | ||||
| 	assert.Equal(t, expected, err.Error()) | ||||
| 	assert.Nil(t, o.realClientIPParser) | ||||
| } | ||||
|  |  | |||
|  | @ -3,7 +3,6 @@ package logger | |||
| import ( | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"os" | ||||
|  | @ -76,6 +75,9 @@ type reqLogMessageData struct { | |||
| 	Username string | ||||
| } | ||||
| 
 | ||||
| // Returns the apparent "real client IP" as a string.
 | ||||
| type GetClientFunc = func(r *http.Request) string | ||||
| 
 | ||||
| // A Logger represents an active logging object that generates lines of
 | ||||
| // output to an io.Writer passed through a formatter. Each logging
 | ||||
| // operation makes a single call to the Writer's Write method. A Logger
 | ||||
|  | @ -88,7 +90,7 @@ type Logger struct { | |||
| 	stdEnabled     bool | ||||
| 	authEnabled    bool | ||||
| 	reqEnabled     bool | ||||
| 	reverseProxy   bool | ||||
| 	getClientFunc  GetClientFunc | ||||
| 	excludePaths   map[string]struct{} | ||||
| 	stdLogTemplate *template.Template | ||||
| 	authTemplate   *template.Template | ||||
|  | @ -103,7 +105,7 @@ func New(flag int) *Logger { | |||
| 		stdEnabled:     true, | ||||
| 		authEnabled:    true, | ||||
| 		reqEnabled:     true, | ||||
| 		reverseProxy:   false, | ||||
| 		getClientFunc:  func(r *http.Request) string { return r.RemoteAddr }, | ||||
| 		excludePaths:   nil, | ||||
| 		stdLogTemplate: template.Must(template.New("std-log").Parse(DefaultStandardLoggingFormat)), | ||||
| 		authTemplate:   template.Must(template.New("auth-log").Parse(DefaultAuthLoggingFormat)), | ||||
|  | @ -153,7 +155,7 @@ func (l *Logger) PrintAuthf(username string, req *http.Request, status AuthStatu | |||
| 		username = "-" | ||||
| 	} | ||||
| 
 | ||||
| 	client := GetClient(req, l.reverseProxy) | ||||
| 	client := l.getClientFunc(req) | ||||
| 
 | ||||
| 	l.mu.Lock() | ||||
| 	defer l.mu.Unlock() | ||||
|  | @ -201,7 +203,7 @@ func (l *Logger) PrintReq(username, upstream string, req *http.Request, url url. | |||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	client := GetClient(req, l.reverseProxy) | ||||
| 	client := l.getClientFunc(req) | ||||
| 
 | ||||
| 	l.mu.Lock() | ||||
| 	defer l.mu.Unlock() | ||||
|  | @ -252,22 +254,6 @@ func (l *Logger) GetFileLineString(calldepth int) string { | |||
| 	return fmt.Sprintf("%s:%d", file, line) | ||||
| } | ||||
| 
 | ||||
| // GetClient parses an HTTP request for the client/remote IP address.
 | ||||
| func GetClient(req *http.Request, reverseProxy bool) string { | ||||
| 	client := req.RemoteAddr | ||||
| 	if reverseProxy { | ||||
| 		if ip := req.Header.Get("X-Real-IP"); ip != "" { | ||||
| 			client = ip | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if c, _, err := net.SplitHostPort(client); err == nil { | ||||
| 		client = c | ||||
| 	} | ||||
| 
 | ||||
| 	return client | ||||
| } | ||||
| 
 | ||||
| // FormatTimestamp returns a formatted timestamp.
 | ||||
| func (l *Logger) FormatTimestamp(ts time.Time) string { | ||||
| 	if l.flag&LUTC != 0 { | ||||
|  | @ -312,11 +298,11 @@ func (l *Logger) SetReqEnabled(e bool) { | |||
| 	l.reqEnabled = e | ||||
| } | ||||
| 
 | ||||
| // SetReverseProxy controls whether logging will trust headers that can be set by a reverse proxy.
 | ||||
| func (l *Logger) SetReverseProxy(e bool) { | ||||
| // SetGetClientFunc sets the function which determines the apparent "real client IP".
 | ||||
| func (l *Logger) SetGetClientFunc(f GetClientFunc) { | ||||
| 	l.mu.Lock() | ||||
| 	defer l.mu.Unlock() | ||||
| 	l.reverseProxy = e | ||||
| 	l.getClientFunc = f | ||||
| } | ||||
| 
 | ||||
| // SetExcludePaths sets the paths to exclude from logging.
 | ||||
|  | @ -392,10 +378,10 @@ func SetReqEnabled(e bool) { | |||
| 	std.SetReqEnabled(e) | ||||
| } | ||||
| 
 | ||||
| // SetReverseProxy controls whether logging will trust headers that can be set
 | ||||
| // by a reverse proxy for the standard logger.
 | ||||
| func SetReverseProxy(e bool) { | ||||
| 	std.SetReverseProxy(e) | ||||
| // SetGetClientFunc sets the function which determines the apparent IP address
 | ||||
| // set by a reverse proxy for the standard logger.
 | ||||
| func SetGetClientFunc(f GetClientFunc) { | ||||
| 	std.SetGetClientFunc(f) | ||||
| } | ||||
| 
 | ||||
| // SetExcludePaths sets the path to exclude from logging, eg: health checks
 | ||||
|  |  | |||
|  | @ -0,0 +1,102 @@ | |||
| package main | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | ||||
| ) | ||||
| 
 | ||||
| type realClientIPParser interface { | ||||
| 	GetRealClientIP(http.Header) (net.IP, error) | ||||
| } | ||||
| 
 | ||||
| func getRealClientIPParser(headerKey string) (realClientIPParser, error) { | ||||
| 	headerKey = http.CanonicalHeaderKey(headerKey) | ||||
| 
 | ||||
| 	switch headerKey { | ||||
| 	case http.CanonicalHeaderKey("X-Forwarded-For"), http.CanonicalHeaderKey("X-Real-IP"), http.CanonicalHeaderKey("X-ProxyUser-IP"): | ||||
| 		return &xForwardedForClientIPParser{header: headerKey}, nil | ||||
| 	} | ||||
| 
 | ||||
| 	// TODO: implement the more standardized but more complex `Forwarded` header.
 | ||||
| 	return nil, fmt.Errorf("the http header key (%s) is either invalid or unsupported", headerKey) | ||||
| } | ||||
| 
 | ||||
| type xForwardedForClientIPParser struct { | ||||
| 	header string | ||||
| } | ||||
| 
 | ||||
| // GetRealClientIP obtain the IP address of the end-user (not proxy).
 | ||||
| // Parses headers sharing the format as specified by:
 | ||||
| // * https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For.
 | ||||
| // Returns the `<client>` portion specified in the above document.
 | ||||
| // Additionally, is capable of parsing IPs with the port included, for v4 in the format "<ip>:<port>" and for v6 in the
 | ||||
| // format "[<ip>]:<port>".  With-port and without-port formats are seamlessly supported concurrently.
 | ||||
| func (p xForwardedForClientIPParser) GetRealClientIP(h http.Header) (net.IP, error) { | ||||
| 	var ipStr string | ||||
| 	if realIP := h.Get(p.header); realIP != "" { | ||||
| 		ipStr = realIP | ||||
| 	} else { | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 
 | ||||
| 	// Each successive proxy may append itself, comma separated, to the end of the X-Forwarded-for header.
 | ||||
| 	// Select only the first IP listed, as it is the client IP recorded by the first proxy.
 | ||||
| 	if commaIndex := strings.IndexRune(ipStr, ','); commaIndex != -1 { | ||||
| 		ipStr = ipStr[:commaIndex] | ||||
| 	} | ||||
| 	ipStr = strings.TrimSpace(ipStr) | ||||
| 
 | ||||
| 	if ipHost, _, err := net.SplitHostPort(ipStr); err == nil { | ||||
| 		ipStr = ipHost | ||||
| 	} | ||||
| 
 | ||||
| 	ip := net.ParseIP(ipStr) | ||||
| 	if ip == nil { | ||||
| 		return nil, fmt.Errorf("unable to parse ip (%s) from %s header", ipStr, http.CanonicalHeaderKey(p.header)) | ||||
| 	} | ||||
| 
 | ||||
| 	return ip, nil | ||||
| } | ||||
| 
 | ||||
| // getRemoteIP obtains the IP of the low-level connected network host
 | ||||
| func getRemoteIP(req *http.Request) (net.IP, error) { | ||||
| 	if ipStr, _, err := net.SplitHostPort(req.RemoteAddr); err != nil { | ||||
| 		return nil, fmt.Errorf("unable to get ip and port from http.RemoteAddr (%s)", req.RemoteAddr) | ||||
| 	} else if ip := net.ParseIP(ipStr); ip != nil { | ||||
| 		return ip, nil | ||||
| 	} else { | ||||
| 		return nil, fmt.Errorf("unable to parse ip (%s)", ipStr) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // getClientString obtains the human readable string of the remote IP and optionally the real client IP if available
 | ||||
| func getClientString(p realClientIPParser, req *http.Request, full bool) (s string) { | ||||
| 	var realClientIPStr string | ||||
| 	if p != nil { | ||||
| 		if realClientIP, err := p.GetRealClientIP(req.Header); err != nil { | ||||
| 			logger.Printf("Unable to get real client IP: %v", err) | ||||
| 		} else if realClientIP != nil { | ||||
| 			realClientIPStr = realClientIP.String() | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	var remoteIPStr string | ||||
| 	if remoteIP, err := getRemoteIP(req); err == nil { | ||||
| 		remoteIPStr = remoteIP.String() | ||||
| 	} else { | ||||
| 		// Should not happen, if it does, likely a bug.
 | ||||
| 		logger.Printf("Unable to get remote IP(?!?!): %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if !full && realClientIPStr != "" { | ||||
| 		return realClientIPStr | ||||
| 	} | ||||
| 	if full && realClientIPStr != "" { | ||||
| 		return fmt.Sprintf("%s (%s)", remoteIPStr, realClientIPStr) | ||||
| 	} | ||||
| 	return remoteIPStr | ||||
| } | ||||
|  | @ -0,0 +1,176 @@ | |||
| package main | ||||
| 
 | ||||
| import ( | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
| 
 | ||||
| func TestGetRealClientIPParser(t *testing.T) { | ||||
| 	forwardedForType := reflect.TypeOf((*xForwardedForClientIPParser)(nil)) | ||||
| 
 | ||||
| 	tests := []struct { | ||||
| 		header     string | ||||
| 		errString  string | ||||
| 		parserType reflect.Type | ||||
| 	}{ | ||||
| 		{"X-Forwarded-For", "", forwardedForType}, | ||||
| 		{"X-REAL-IP", "", forwardedForType}, | ||||
| 		{"x-proxyuser-ip", "", forwardedForType}, | ||||
| 		{"", "the http header key () is either invalid or unsupported", nil}, | ||||
| 		{"Forwarded", "the http header key (Forwarded) is either invalid or unsupported", nil}, | ||||
| 		{"2#* @##$$:kd", "the http header key (2#* @##$$:kd) is either invalid or unsupported", nil}, | ||||
| 	} | ||||
| 
 | ||||
| 	for _, test := range tests { | ||||
| 		p, err := getRealClientIPParser(test.header) | ||||
| 
 | ||||
| 		if test.errString == "" { | ||||
| 			assert.Nil(t, err) | ||||
| 		} else { | ||||
| 			assert.NotNil(t, err) | ||||
| 			assert.Equal(t, test.errString, err.Error()) | ||||
| 		} | ||||
| 
 | ||||
| 		if test.parserType == nil { | ||||
| 			assert.Nil(t, p) | ||||
| 		} else { | ||||
| 			assert.NotNil(t, p) | ||||
| 			assert.Equal(t, test.parserType, reflect.TypeOf(p)) | ||||
| 		} | ||||
| 
 | ||||
| 		if xp, ok := p.(*xForwardedForClientIPParser); ok { | ||||
| 			assert.Equal(t, http.CanonicalHeaderKey(test.header), xp.header) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestXForwardedForClientIPParser(t *testing.T) { | ||||
| 	p := &xForwardedForClientIPParser{header: http.CanonicalHeaderKey("X-Forwarded-For")} | ||||
| 
 | ||||
| 	tests := []struct { | ||||
| 		headerValue string | ||||
| 		errString   string | ||||
| 		expectedIP  net.IP | ||||
| 	}{ | ||||
| 		{"", "", nil}, | ||||
| 		{"1.2.3.4", "", net.ParseIP("1.2.3.4")}, | ||||
| 		{"10::23", "", net.ParseIP("10::23")}, | ||||
| 		{"::1", "", net.ParseIP("::1")}, | ||||
| 		{"[::1]:1234", "", net.ParseIP("::1")}, | ||||
| 		{"10.0.10.11:1234", "", net.ParseIP("10.0.10.11")}, | ||||
| 		{"192.168.10.50, 10.0.0.1, 1.2.3.4", "", net.ParseIP("192.168.10.50")}, | ||||
| 		{"nil", "unable to parse ip (nil) from X-Forwarded-For header", nil}, | ||||
| 		{"10000.10000.10000.10000", "unable to parse ip (10000.10000.10000.10000) from X-Forwarded-For header", nil}, | ||||
| 	} | ||||
| 
 | ||||
| 	for _, test := range tests { | ||||
| 		h := http.Header{} | ||||
| 		h.Add("X-Forwarded-For", test.headerValue) | ||||
| 
 | ||||
| 		ip, err := p.GetRealClientIP(h) | ||||
| 
 | ||||
| 		if test.errString == "" { | ||||
| 			assert.Nil(t, err) | ||||
| 		} else { | ||||
| 			assert.NotNil(t, err) | ||||
| 			assert.Equal(t, test.errString, err.Error()) | ||||
| 		} | ||||
| 
 | ||||
| 		if test.expectedIP == nil { | ||||
| 			assert.Nil(t, ip) | ||||
| 		} else { | ||||
| 			assert.NotNil(t, ip) | ||||
| 			assert.Equal(t, test.expectedIP, ip) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestXForwardedForClientIPParserIgnoresOthers(t *testing.T) { | ||||
| 	p := &xForwardedForClientIPParser{header: http.CanonicalHeaderKey("X-Forwarded-For")} | ||||
| 
 | ||||
| 	h := http.Header{} | ||||
| 	expectedIPString := "192.168.10.50" | ||||
| 	h.Add("X-Real-IP", "10.0.0.1") | ||||
| 	h.Add("X-ProxyUser-IP", "10.0.0.1") | ||||
| 	h.Add("X-Forwarded-For", expectedIPString) | ||||
| 	ip, err := p.GetRealClientIP(h) | ||||
| 	assert.Nil(t, err) | ||||
| 	assert.NotNil(t, ip) | ||||
| 	assert.Equal(t, ip, net.ParseIP(expectedIPString)) | ||||
| } | ||||
| 
 | ||||
| func TestGetRemoteIP(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		remoteAddr string | ||||
| 		errString  string | ||||
| 		expectedIP net.IP | ||||
| 	}{ | ||||
| 		{"", "unable to get ip and port from http.RemoteAddr ()", nil}, | ||||
| 		{"nil", "unable to get ip and port from http.RemoteAddr (nil)", nil}, | ||||
| 		{"235.28.129.186", "unable to get ip and port from http.RemoteAddr (235.28.129.186)", nil}, | ||||
| 		{"90::45", "unable to get ip and port from http.RemoteAddr (90::45)", nil}, | ||||
| 		{"192.168.73.165:14976, 10.4.201.15:18453", "unable to get ip and port from http.RemoteAddr (192.168.73.165:14976, 10.4.201.15:18453)", nil}, | ||||
| 		{"10000.10000.10000.10000:8080", "unable to parse ip (10000.10000.10000.10000)", nil}, | ||||
| 		{"[::1]:48290", "", net.ParseIP("::1")}, | ||||
| 		{"10.254.244.165:62750", "", net.ParseIP("10.254.244.165")}, | ||||
| 	} | ||||
| 
 | ||||
| 	for _, test := range tests { | ||||
| 		req := &http.Request{RemoteAddr: test.remoteAddr} | ||||
| 
 | ||||
| 		ip, err := getRemoteIP(req) | ||||
| 
 | ||||
| 		if test.errString == "" { | ||||
| 			assert.Nil(t, err) | ||||
| 		} else { | ||||
| 			assert.NotNil(t, err) | ||||
| 			assert.Equal(t, test.errString, err.Error()) | ||||
| 		} | ||||
| 
 | ||||
| 		if test.expectedIP == nil { | ||||
| 			assert.Nil(t, ip) | ||||
| 		} else { | ||||
| 			assert.NotNil(t, ip) | ||||
| 			assert.Equal(t, test.expectedIP, ip) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestGetClientString(t *testing.T) { | ||||
| 	p := &xForwardedForClientIPParser{header: http.CanonicalHeaderKey("X-Forwarded-For")} | ||||
| 
 | ||||
| 	tests := []struct { | ||||
| 		parser             realClientIPParser | ||||
| 		remoteAddr         string | ||||
| 		headerValue        string | ||||
| 		expectedClient     string | ||||
| 		expectedClientFull string | ||||
| 	}{ | ||||
| 		// Should fail quietly, only printing warnings to the log
 | ||||
| 		{nil, "", "", "", ""}, | ||||
| 		{p, "127.0.0.1:11950", "", "127.0.0.1", "127.0.0.1"}, | ||||
| 		{p, "[::1]:28660", "99.103.56.12", "99.103.56.12", "::1 (99.103.56.12)"}, | ||||
| 		{nil, "10.254.244.165:62750", "", "10.254.244.165", "10.254.244.165"}, | ||||
| 		// Parser is nil, the contents of X-Forwarded-For should be ignored in all cases.
 | ||||
| 		{nil, "[2001:470:26:307:a5a1:1177:2ae3:e9c3]:48290", "127.0.0.1", "2001:470:26:307:a5a1:1177:2ae3:e9c3", "2001:470:26:307:a5a1:1177:2ae3:e9c3"}, | ||||
| 	} | ||||
| 
 | ||||
| 	for _, test := range tests { | ||||
| 		h := http.Header{} | ||||
| 		h.Add("X-Forwarded-For", test.headerValue) | ||||
| 		req := &http.Request{ | ||||
| 			Header:     h, | ||||
| 			RemoteAddr: test.remoteAddr, | ||||
| 		} | ||||
| 
 | ||||
| 		client := getClientString(test.parser, req, false) | ||||
| 		assert.Equal(t, test.expectedClient, client) | ||||
| 
 | ||||
| 		clientFull := getClientString(test.parser, req, true) | ||||
| 		assert.Equal(t, test.expectedClientFull, clientFull) | ||||
| 	} | ||||
| } | ||||
		Loading…
	
		Reference in New Issue