Catch more options errors at once; add test
This commit is contained in:
		
							parent
							
								
									30e5b636bf
								
							
						
					
					
						commit
						d751bbea4c
					
				
							
								
								
									
										25
									
								
								options.go
								
								
								
								
							
							
						
						
									
										25
									
								
								options.go
								
								
								
								
							| 
						 | 
					@ -1,10 +1,10 @@
 | 
				
			||||||
package main
 | 
					package main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"errors"
 | 
					 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"net/url"
 | 
						"net/url"
 | 
				
			||||||
	"regexp"
 | 
						"regexp"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -45,29 +45,33 @@ func NewOptions() *Options {
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (o *Options) Validate() error {
 | 
					func (o *Options) Validate() error {
 | 
				
			||||||
 | 
						msgs := make([]string, 0)
 | 
				
			||||||
	if len(o.Upstreams) < 1 {
 | 
						if len(o.Upstreams) < 1 {
 | 
				
			||||||
		return errors.New("missing setting: upstream")
 | 
							msgs = append(msgs, "missing setting: upstream")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if o.CookieSecret == "" {
 | 
						if o.CookieSecret == "" {
 | 
				
			||||||
		errors.New("missing setting: cookie-secret")
 | 
							msgs = append(msgs, "missing setting: cookie-secret")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if o.ClientID == "" {
 | 
						if o.ClientID == "" {
 | 
				
			||||||
		return errors.New("missing setting: client-id")
 | 
							msgs = append(msgs, "missing setting: client-id")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if o.ClientSecret == "" {
 | 
						if o.ClientSecret == "" {
 | 
				
			||||||
		return errors.New("missing setting: client-secret")
 | 
							msgs = append(msgs, "missing setting: client-secret")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	redirectUrl, err := url.Parse(o.RedirectUrl)
 | 
						redirectUrl, err := url.Parse(o.RedirectUrl)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return fmt.Errorf("error parsing redirect-url=%q %s", o.RedirectUrl, err)
 | 
							msgs = append(msgs, fmt.Sprintf(
 | 
				
			||||||
 | 
								"error parsing redirect-url=%q %s", o.RedirectUrl, err))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	o.redirectUrl = redirectUrl
 | 
						o.redirectUrl = redirectUrl
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, u := range o.Upstreams {
 | 
						for _, u := range o.Upstreams {
 | 
				
			||||||
		upstreamUrl, err := url.Parse(u)
 | 
							upstreamUrl, err := url.Parse(u)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return fmt.Errorf("error parsing upstream=%q %s", upstreamUrl, err)
 | 
								msgs = append(msgs, fmt.Sprintf(
 | 
				
			||||||
 | 
									"error parsing upstream=%q %s",
 | 
				
			||||||
 | 
									upstreamUrl, err))
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if upstreamUrl.Path == "" {
 | 
							if upstreamUrl.Path == "" {
 | 
				
			||||||
			upstreamUrl.Path = "/"
 | 
								upstreamUrl.Path = "/"
 | 
				
			||||||
| 
						 | 
					@ -78,10 +82,15 @@ func (o *Options) Validate() error {
 | 
				
			||||||
	for _, u := range o.SkipAuthRegex {
 | 
						for _, u := range o.SkipAuthRegex {
 | 
				
			||||||
		CompiledRegex, err := regexp.Compile(u)
 | 
							CompiledRegex, err := regexp.Compile(u)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return fmt.Errorf("error compiling regex=%q %s", u, err)
 | 
								msgs = append(msgs, fmt.Sprintf(
 | 
				
			||||||
 | 
									"error compiling regex=%q %s", u, err))
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		o.CompiledRegex = append(o.CompiledRegex, CompiledRegex)
 | 
							o.CompiledRegex = append(o.CompiledRegex, CompiledRegex)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if len(msgs) != 0 {
 | 
				
			||||||
 | 
							return fmt.Errorf("Invalid configuration:\n  %s",
 | 
				
			||||||
 | 
								strings.Join(msgs, "\n  "))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,92 @@
 | 
				
			||||||
 | 
					package main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
						"net/url"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/bmizerany/assert"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func testOptions() (*Options) {
 | 
				
			||||||
 | 
						o := NewOptions()
 | 
				
			||||||
 | 
						o.Upstreams = append(o.Upstreams, "http://127.0.0.1:8080/")
 | 
				
			||||||
 | 
						o.CookieSecret = "foobar"
 | 
				
			||||||
 | 
						o.ClientID = "bazquux"
 | 
				
			||||||
 | 
						o.ClientSecret = "xyzzyplugh"
 | 
				
			||||||
 | 
						return o
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func errorMsg(msgs []string)(string) {
 | 
				
			||||||
 | 
						result := make([]string, 0)
 | 
				
			||||||
 | 
						result = append(result, "Invalid configuration:")
 | 
				
			||||||
 | 
						result = append(result, msgs...)
 | 
				
			||||||
 | 
						return strings.Join(result, "\n  ")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestNewOptions(t *testing.T) {
 | 
				
			||||||
 | 
						o := NewOptions()
 | 
				
			||||||
 | 
						err := o.Validate()
 | 
				
			||||||
 | 
						assert.NotEqual(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						expected := errorMsg([]string{
 | 
				
			||||||
 | 
							"missing setting: upstream",
 | 
				
			||||||
 | 
							"missing setting: cookie-secret",
 | 
				
			||||||
 | 
							"missing setting: client-id",
 | 
				
			||||||
 | 
							"missing setting: client-secret"})
 | 
				
			||||||
 | 
						assert.Equal(t, expected, err.Error())
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestInitializedOptions(t *testing.T) {
 | 
				
			||||||
 | 
						o := testOptions()
 | 
				
			||||||
 | 
						assert.Equal(t, nil, o.Validate())
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Note that it's not worth testing nonparseable URLs, since url.Parse()
 | 
				
			||||||
 | 
					// seems to parse damn near anything.
 | 
				
			||||||
 | 
					func TestRedirectUrl(t *testing.T) {
 | 
				
			||||||
 | 
						o := testOptions()
 | 
				
			||||||
 | 
						o.RedirectUrl = "https://myhost.com/oauth2/callback"
 | 
				
			||||||
 | 
						assert.Equal(t, nil, o.Validate())
 | 
				
			||||||
 | 
						expected := &url.URL{
 | 
				
			||||||
 | 
							Scheme: "https", Host: "myhost.com", Path: "/oauth2/callback"}
 | 
				
			||||||
 | 
						assert.Equal(t, expected, o.redirectUrl)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestProxyUrls(t *testing.T) {
 | 
				
			||||||
 | 
						o := testOptions()
 | 
				
			||||||
 | 
						o.Upstreams = append(o.Upstreams, "http://127.0.0.1:8081")
 | 
				
			||||||
 | 
						assert.Equal(t, nil, o.Validate())
 | 
				
			||||||
 | 
						expected := []*url.URL{
 | 
				
			||||||
 | 
							&url.URL{Scheme: "http", Host: "127.0.0.1:8080", Path: "/"},
 | 
				
			||||||
 | 
							// note the '/' was added
 | 
				
			||||||
 | 
							&url.URL{Scheme: "http", Host: "127.0.0.1:8081", Path: "/"},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						assert.Equal(t, expected, o.proxyUrls)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestCompiledRegex(t *testing.T) {
 | 
				
			||||||
 | 
						o := testOptions()
 | 
				
			||||||
 | 
						regexps := []string{"/foo/.*", "/ba[rz]/quux"}
 | 
				
			||||||
 | 
						o.SkipAuthRegex = regexps
 | 
				
			||||||
 | 
						assert.Equal(t, nil, o.Validate())
 | 
				
			||||||
 | 
						actual := make([]string, 0)
 | 
				
			||||||
 | 
						for _, regex := range o.CompiledRegex {
 | 
				
			||||||
 | 
							actual = append(actual, regex.String())
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						assert.Equal(t, regexps, actual)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestCompiledRegexError(t *testing.T) {
 | 
				
			||||||
 | 
						o := testOptions()
 | 
				
			||||||
 | 
						o.SkipAuthRegex = []string{"(foobaz", "barquux)"}
 | 
				
			||||||
 | 
						err := o.Validate()
 | 
				
			||||||
 | 
						assert.NotEqual(t, nil, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						expected := errorMsg([]string{
 | 
				
			||||||
 | 
							"error compiling regex=\"(foobaz\" error parsing regexp: " +
 | 
				
			||||||
 | 
								"missing closing ): `(foobaz`",
 | 
				
			||||||
 | 
							"error compiling regex=\"barquux)\" error parsing regexp: " +
 | 
				
			||||||
 | 
								"unexpected ): `barquux)`"})
 | 
				
			||||||
 | 
						assert.Equal(t, expected, err.Error())
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
		Loading…
	
		Reference in New Issue