Introduce Duration so that marshalling works for duration strings
This commit is contained in:
		
							parent
							
								
									ed92df3537
								
							
						
					
					
						commit
						b6d6f31ac1
					
				|  | @ -1,5 +1,11 @@ | |||
| package options | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
| ) | ||||
| 
 | ||||
| // SecretSource references an individual secret value.
 | ||||
| // Only one source within the struct should be defined at any time.
 | ||||
| type SecretSource struct { | ||||
|  | @ -12,3 +18,45 @@ type SecretSource struct { | |||
| 	// FromFile expects a path to a file containing the secret value.
 | ||||
| 	FromFile string | ||||
| } | ||||
| 
 | ||||
| // Duration is an alias for time.Duration so that we can ensure the marshalling
 | ||||
| // and unmarshalling of string durations is done as users expect.
 | ||||
| // Intentional blank line below to keep this first part of the comment out of
 | ||||
| // any generated references.
 | ||||
| 
 | ||||
| // Duration is as string representation of a period of time.
 | ||||
| // A duration string is a is a possibly signed sequence of decimal numbers,
 | ||||
| // each with optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
 | ||||
| // Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
 | ||||
| type Duration time.Duration | ||||
| 
 | ||||
| // UnmarshalJSON parses the duration string and sets the value of duration
 | ||||
| // to the value of the duration string.
 | ||||
| func (d *Duration) UnmarshalJSON(data []byte) error { | ||||
| 	input := string(data) | ||||
| 	if unquoted, err := strconv.Unquote(input); err == nil { | ||||
| 		input = unquoted | ||||
| 	} | ||||
| 
 | ||||
| 	du, err := time.ParseDuration(input) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	*d = Duration(du) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // MarshalJSON ensures that when the string is marshalled to JSON as a human
 | ||||
| // readable string.
 | ||||
| func (d *Duration) MarshalJSON() ([]byte, error) { | ||||
| 	dStr := fmt.Sprintf("%q", d.Duration().String()) | ||||
| 	return []byte(dStr), nil | ||||
| } | ||||
| 
 | ||||
| // Duration returns the time.Duration version of this Duration
 | ||||
| func (d *Duration) Duration() time.Duration { | ||||
| 	if d == nil { | ||||
| 		return time.Duration(0) | ||||
| 	} | ||||
| 	return time.Duration(*d) | ||||
| } | ||||
|  |  | |||
|  | @ -0,0 +1,88 @@ | |||
| package options | ||||
| 
 | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"time" | ||||
| 
 | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/ginkgo/extensions/table" | ||||
| 	. "github.com/onsi/gomega" | ||||
| ) | ||||
| 
 | ||||
| var _ = Describe("Common", func() { | ||||
| 	Context("Duration", func() { | ||||
| 		type marshalJSONTableInput struct { | ||||
| 			duration     Duration | ||||
| 			expectedJSON string | ||||
| 		} | ||||
| 
 | ||||
| 		DescribeTable("MarshalJSON", | ||||
| 			func(in marshalJSONTableInput) { | ||||
| 				data, err := in.duration.MarshalJSON() | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(string(data)).To(Equal(in.expectedJSON)) | ||||
| 
 | ||||
| 				var d Duration | ||||
| 				Expect(json.Unmarshal(data, &d)).To(Succeed()) | ||||
| 				Expect(d).To(Equal(in.duration)) | ||||
| 			}, | ||||
| 			Entry("30 seconds", marshalJSONTableInput{ | ||||
| 				duration:     Duration(30 * time.Second), | ||||
| 				expectedJSON: "\"30s\"", | ||||
| 			}), | ||||
| 			Entry("1 minute", marshalJSONTableInput{ | ||||
| 				duration:     Duration(1 * time.Minute), | ||||
| 				expectedJSON: "\"1m0s\"", | ||||
| 			}), | ||||
| 			Entry("1 hour 15 minutes", marshalJSONTableInput{ | ||||
| 				duration:     Duration(75 * time.Minute), | ||||
| 				expectedJSON: "\"1h15m0s\"", | ||||
| 			}), | ||||
| 			Entry("A zero Duration", marshalJSONTableInput{ | ||||
| 				duration:     Duration(0), | ||||
| 				expectedJSON: "\"0s\"", | ||||
| 			}), | ||||
| 		) | ||||
| 
 | ||||
| 		type unmarshalJSONTableInput struct { | ||||
| 			json             string | ||||
| 			expectedErr      error | ||||
| 			expectedDuration Duration | ||||
| 		} | ||||
| 
 | ||||
| 		DescribeTable("UnmarshalJSON", | ||||
| 			func(in unmarshalJSONTableInput) { | ||||
| 				// A duration must be initialised pointer before UnmarshalJSON will work.
 | ||||
| 				zero := Duration(0) | ||||
| 				d := &zero | ||||
| 
 | ||||
| 				err := d.UnmarshalJSON([]byte(in.json)) | ||||
| 				if in.expectedErr != nil { | ||||
| 					Expect(err).To(MatchError(in.expectedErr.Error())) | ||||
| 				} else { | ||||
| 					Expect(err).ToNot(HaveOccurred()) | ||||
| 				} | ||||
| 				Expect(d).ToNot(BeNil()) | ||||
| 				Expect(*d).To(Equal(in.expectedDuration)) | ||||
| 			}, | ||||
| 			Entry("1m", unmarshalJSONTableInput{ | ||||
| 				json:             "\"1m\"", | ||||
| 				expectedDuration: Duration(1 * time.Minute), | ||||
| 			}), | ||||
| 			Entry("30s", unmarshalJSONTableInput{ | ||||
| 				json:             "\"30s\"", | ||||
| 				expectedDuration: Duration(30 * time.Second), | ||||
| 			}), | ||||
| 			Entry("1h15m", unmarshalJSONTableInput{ | ||||
| 				json:             "\"1h15m\"", | ||||
| 				expectedDuration: Duration(75 * time.Minute), | ||||
| 			}), | ||||
| 			Entry("am", unmarshalJSONTableInput{ | ||||
| 				json:             "\"am\"", | ||||
| 				expectedErr:      errors.New("time: invalid duration \"am\""), | ||||
| 				expectedDuration: Duration(0), | ||||
| 			}), | ||||
| 		) | ||||
| 	}) | ||||
| }) | ||||
|  | @ -84,6 +84,7 @@ func (l *LegacyUpstreams) convert() (Upstreams, error) { | |||
| 			u.Path = "/" | ||||
| 		} | ||||
| 
 | ||||
| 		flushInterval := Duration(l.FlushInterval) | ||||
| 		upstream := Upstream{ | ||||
| 			ID:                    u.Path, | ||||
| 			Path:                  u.Path, | ||||
|  | @ -91,7 +92,7 @@ func (l *LegacyUpstreams) convert() (Upstreams, error) { | |||
| 			InsecureSkipTLSVerify: l.SSLUpstreamInsecureSkipVerify, | ||||
| 			PassHostHeader:        &l.PassHostHeader, | ||||
| 			ProxyWebSockets:       &l.ProxyWebSockets, | ||||
| 			FlushInterval:         &l.FlushInterval, | ||||
| 			FlushInterval:         &flushInterval, | ||||
| 		} | ||||
| 
 | ||||
| 		switch u.Scheme { | ||||
|  |  | |||
|  | @ -17,8 +17,8 @@ var _ = Describe("Legacy Options", func() { | |||
| 			legacyOpts := NewLegacyOptions() | ||||
| 
 | ||||
| 			// Set upstreams and related options to test their conversion
 | ||||
| 			flushInterval := 5 * time.Second | ||||
| 			legacyOpts.LegacyUpstreams.FlushInterval = flushInterval | ||||
| 			flushInterval := Duration(5 * time.Second) | ||||
| 			legacyOpts.LegacyUpstreams.FlushInterval = time.Duration(flushInterval) | ||||
| 			legacyOpts.LegacyUpstreams.PassHostHeader = true | ||||
| 			legacyOpts.LegacyUpstreams.ProxyWebSockets = true | ||||
| 			legacyOpts.LegacyUpstreams.SSLUpstreamInsecureSkipVerify = true | ||||
|  | @ -124,7 +124,7 @@ var _ = Describe("Legacy Options", func() { | |||
| 		skipVerify := true | ||||
| 		passHostHeader := false | ||||
| 		proxyWebSockets := true | ||||
| 		flushInterval := 5 * time.Second | ||||
| 		flushInterval := Duration(5 * time.Second) | ||||
| 
 | ||||
| 		// Test cases and expected outcomes
 | ||||
| 		validHTTP := "http://foo.bar/baz" | ||||
|  | @ -199,7 +199,7 @@ var _ = Describe("Legacy Options", func() { | |||
| 					SSLUpstreamInsecureSkipVerify: skipVerify, | ||||
| 					PassHostHeader:                passHostHeader, | ||||
| 					ProxyWebSockets:               proxyWebSockets, | ||||
| 					FlushInterval:                 flushInterval, | ||||
| 					FlushInterval:                 time.Duration(flushInterval), | ||||
| 				} | ||||
| 
 | ||||
| 				upstreams, err := legacyUpstreams.convert() | ||||
|  |  | |||
|  | @ -1,7 +1,5 @@ | |||
| package options | ||||
| 
 | ||||
| import "time" | ||||
| 
 | ||||
| // Upstreams is a collection of definitions for upstream servers.
 | ||||
| type Upstreams []Upstream | ||||
| 
 | ||||
|  | @ -47,7 +45,7 @@ type Upstream struct { | |||
| 	// FlushInterval is the period between flushing the response buffer when
 | ||||
| 	// streaming response from the upstream.
 | ||||
| 	// Defaults to 1 second.
 | ||||
| 	FlushInterval *time.Duration `json:"flushInterval,omitempty"` | ||||
| 	FlushInterval *Duration `json:"flushInterval,omitempty"` | ||||
| 
 | ||||
| 	// PassHostHeader determines whether the request host header should be proxied
 | ||||
| 	// to the upstream server.
 | ||||
|  |  | |||
|  | @ -98,7 +98,7 @@ func newReverseProxy(target *url.URL, upstream options.Upstream, errorHandler Pr | |||
| 
 | ||||
| 	// Configure options on the SingleHostReverseProxy
 | ||||
| 	if upstream.FlushInterval != nil { | ||||
| 		proxy.FlushInterval = *upstream.FlushInterval | ||||
| 		proxy.FlushInterval = upstream.FlushInterval.Duration() | ||||
| 	} else { | ||||
| 		proxy.FlushInterval = 1 * time.Second | ||||
| 	} | ||||
|  |  | |||
|  | @ -22,8 +22,8 @@ import ( | |||
| 
 | ||||
| var _ = Describe("HTTP Upstream Suite", func() { | ||||
| 
 | ||||
| 	const flushInterval5s = 5 * time.Second | ||||
| 	const flushInterval1s = 1 * time.Second | ||||
| 	const flushInterval5s = options.Duration(5 * time.Second) | ||||
| 	const flushInterval1s = options.Duration(1 * time.Second) | ||||
| 	truth := true | ||||
| 	falsum := false | ||||
| 
 | ||||
|  | @ -52,7 +52,7 @@ var _ = Describe("HTTP Upstream Suite", func() { | |||
| 
 | ||||
| 			rw := httptest.NewRecorder() | ||||
| 
 | ||||
| 			flush := 1 * time.Second | ||||
| 			flush := options.Duration(1 * time.Second) | ||||
| 
 | ||||
| 			upstream := options.Upstream{ | ||||
| 				ID:                    in.id, | ||||
|  | @ -258,7 +258,7 @@ var _ = Describe("HTTP Upstream Suite", func() { | |||
| 		req := httptest.NewRequest("", "http://example.localhost/foo", nil) | ||||
| 		rw := httptest.NewRecorder() | ||||
| 
 | ||||
| 		flush := 1 * time.Second | ||||
| 		flush := options.Duration(1 * time.Second) | ||||
| 		upstream := options.Upstream{ | ||||
| 			ID:                    "noPassHost", | ||||
| 			PassHostHeader:        &falsum, | ||||
|  | @ -290,7 +290,7 @@ var _ = Describe("HTTP Upstream Suite", func() { | |||
| 
 | ||||
| 	type newUpstreamTableInput struct { | ||||
| 		proxyWebSockets bool | ||||
| 		flushInterval   time.Duration | ||||
| 		flushInterval   options.Duration | ||||
| 		skipVerify      bool | ||||
| 		sigData         *options.SignatureData | ||||
| 		errorHandler    func(http.ResponseWriter, *http.Request, error) | ||||
|  | @ -319,7 +319,7 @@ var _ = Describe("HTTP Upstream Suite", func() { | |||
| 
 | ||||
| 			proxy, ok := upstreamProxy.handler.(*httputil.ReverseProxy) | ||||
| 			Expect(ok).To(BeTrue()) | ||||
| 			Expect(proxy.FlushInterval).To(Equal(in.flushInterval)) | ||||
| 			Expect(proxy.FlushInterval).To(Equal(in.flushInterval.Duration())) | ||||
| 			Expect(proxy.ErrorHandler != nil).To(Equal(in.errorHandler != nil)) | ||||
| 			if in.skipVerify { | ||||
| 				Expect(proxy.Transport).To(Equal(&http.Transport{ | ||||
|  | @ -370,7 +370,7 @@ var _ = Describe("HTTP Upstream Suite", func() { | |||
| 		var proxyServer *httptest.Server | ||||
| 
 | ||||
| 		BeforeEach(func() { | ||||
| 			flush := 1 * time.Second | ||||
| 			flush := options.Duration(1 * time.Second) | ||||
| 			upstream := options.Upstream{ | ||||
| 				ID:                    "websocketProxy", | ||||
| 				PassHostHeader:        &truth, | ||||
|  |  | |||
|  | @ -70,7 +70,7 @@ func validateStaticUpstream(upstream options.Upstream) []string { | |||
| 	if upstream.InsecureSkipTLSVerify { | ||||
| 		msgs = append(msgs, fmt.Sprintf("upstream %q has insecureSkipTLSVerify, but is a static upstream, this will have no effect.", upstream.ID)) | ||||
| 	} | ||||
| 	if upstream.FlushInterval != nil && *upstream.FlushInterval != time.Second { | ||||
| 	if upstream.FlushInterval != nil && upstream.FlushInterval.Duration() != time.Second { | ||||
| 		msgs = append(msgs, fmt.Sprintf("upstream %q has flushInterval, but is a static upstream, this will have no effect.", upstream.ID)) | ||||
| 	} | ||||
| 	if upstream.PassHostHeader != nil { | ||||
|  |  | |||
|  | @ -15,7 +15,7 @@ var _ = Describe("Upstreams", func() { | |||
| 		errStrings []string | ||||
| 	} | ||||
| 
 | ||||
| 	flushInterval := 5 * time.Second | ||||
| 	flushInterval := options.Duration(5 * time.Second) | ||||
| 	staticCode200 := 200 | ||||
| 	truth := true | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue