diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 218b4426..b2abe379 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -657,6 +657,59 @@ func TestManualSignInStoresUserGroupsInTheSession(t *testing.T) { assert.Equal(t, userGroups, s.Groups) } +type ManualSignInValidator struct{} + +func (ManualSignInValidator) Validate(user, password string) bool { + switch { + case user == "admin" && password == "adminPass": + return true + default: + return false + } +} + +func ManualSignInWithCredentials(t *testing.T, user, pass string) int { + opts := baseTestOptions() + err := validation.Validate(opts) + if err != nil { + t.Fatal(err) + } + + proxy, err := NewOAuthProxy(opts, func(email string) bool { + return true + }) + if err != nil { + t.Fatal(err) + } + + proxy.basicAuthValidator = ManualSignInValidator{} + + rw := httptest.NewRecorder() + formData := url.Values{} + formData.Set("username", user) + formData.Set("password", pass) + signInReq, _ := http.NewRequest(http.MethodPost, "/oauth2/sign_in", strings.NewReader(formData.Encode())) + signInReq.Header.Add("Content-Type", "application/x-www-form-urlencoded") + proxy.ServeHTTP(rw, signInReq) + + return rw.Code +} + +func TestManualSignInEmptyUsernameAlert(t *testing.T) { + statusCode := ManualSignInWithCredentials(t, "", "") + assert.Equal(t, http.StatusBadRequest, statusCode) +} + +func TestManualSignInInvalidCredentialsAlert(t *testing.T) { + statusCode := ManualSignInWithCredentials(t, "admin", "") + assert.Equal(t, http.StatusUnauthorized, statusCode) +} + +func TestManualSignInCorrectCredentials(t *testing.T) { + statusCode := ManualSignInWithCredentials(t, "admin", "adminPass") + assert.Equal(t, http.StatusFound, statusCode) +} + func TestSignInPageIncludesTargetRedirect(t *testing.T) { sipTest, err := NewSignInPageTest(false) if err != nil { diff --git a/pkg/app/pagewriter/pagewriter_test.go b/pkg/app/pagewriter/pagewriter_test.go index 2adedd19..cfea153b 100644 --- a/pkg/app/pagewriter/pagewriter_test.go +++ b/pkg/app/pagewriter/pagewriter_test.go @@ -57,7 +57,7 @@ var _ = Describe("Writer", func() { It("Writes the default sign in template", func() { recorder := httptest.NewRecorder() - writer.WriteSignInPage(recorder, request, "/redirect") + writer.WriteSignInPage(recorder, request, "/redirect", http.StatusOK) body, err := ioutil.ReadAll(recorder.Result().Body) Expect(err).ToNot(HaveOccurred()) @@ -104,7 +104,7 @@ var _ = Describe("Writer", func() { It("Writes the custom sign in template", func() { recorder := httptest.NewRecorder() - writer.WriteSignInPage(recorder, request, "/redirect") + writer.WriteSignInPage(recorder, request, "/redirect", http.StatusOK) body, err := ioutil.ReadAll(recorder.Result().Body) Expect(err).ToNot(HaveOccurred()) @@ -151,7 +151,7 @@ var _ = Describe("Writer", func() { rw := httptest.NewRecorder() req := httptest.NewRequest("", "/sign-in", nil) redirectURL := "" - in.writer.WriteSignInPage(rw, req, redirectURL) + in.writer.WriteSignInPage(rw, req, redirectURL, http.StatusOK) Expect(rw.Result().StatusCode).To(Equal(in.expectedStatus)) @@ -166,7 +166,7 @@ var _ = Describe("Writer", func() { }), Entry("With an override function", writerFuncsTableInput{ writer: &WriterFuncs{ - SignInPageFunc: func(rw http.ResponseWriter, req *http.Request, redirectURL string) { + SignInPageFunc: func(rw http.ResponseWriter, req *http.Request, redirectURL string, statusCode int) { rw.WriteHeader(202) rw.Write([]byte(fmt.Sprintf("%s %s", req.URL.Path, redirectURL))) }, diff --git a/pkg/app/pagewriter/sign_in_page_test.go b/pkg/app/pagewriter/sign_in_page_test.go index 804f45b0..05c0b2e9 100644 --- a/pkg/app/pagewriter/sign_in_page_test.go +++ b/pkg/app/pagewriter/sign_in_page_test.go @@ -54,7 +54,7 @@ var _ = Describe("SignIn Page", func() { Context("WriteSignInPage", func() { It("Writes the template to the response writer", func() { recorder := httptest.NewRecorder() - signInPage.WriteSignInPage(recorder, request, "/redirect") + signInPage.WriteSignInPage(recorder, request, "/redirect", http.StatusOK) body, err := ioutil.ReadAll(recorder.Result().Body) Expect(err).ToNot(HaveOccurred()) @@ -68,7 +68,7 @@ var _ = Describe("SignIn Page", func() { signInPage.template = tmpl recorder := httptest.NewRecorder() - signInPage.WriteSignInPage(recorder, request, "/redirect") + signInPage.WriteSignInPage(recorder, request, "/redirect", http.StatusOK) body, err := ioutil.ReadAll(recorder.Result().Body) Expect(err).ToNot(HaveOccurred())