diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b91fd81..b0243de4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ ## Changes since v7.0.1 +- [#1104](https://github.com/oauth2-proxy/oauth2-proxy/pull/1104) Allow custom robots text pages (@JoelSpeed) - [#1045](https://github.com/oauth2-proxy/oauth2-proxy/pull/1045) Ensure redirect URI always has a scheme (@JoelSpeed) - [#1103](https://github.com/oauth2-proxy/oauth2-proxy/pull/1103) Deprecate upstream request signatures (@NickMeves) - [1087](https://github.com/oauth2-proxy/oauth2-proxy/pull/1087) Support Request ID in logging (@NickMeves) diff --git a/oauthproxy.go b/oauthproxy.go index 588d827a..e2b89ec7 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -577,13 +577,7 @@ func (p *OAuthProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) { // RobotsTxt disallows scraping pages from the OAuthProxy func (p *OAuthProxy) RobotsTxt(rw http.ResponseWriter, req *http.Request) { - _, err := fmt.Fprintf(rw, "User-agent: *\nDisallow: /") - if err != nil { - logger.Printf("Error writing robots.txt: %v", err) - p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error()) - return - } - rw.WriteHeader(http.StatusOK) + p.pageWriter.WriteRobotsTxt(rw, req) } // ErrorPage writes an error response diff --git a/oauthproxy_test.go b/oauthproxy_test.go index a2805c34..42b9b042 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -57,7 +57,7 @@ func TestRobotsTxt(t *testing.T) { req, _ := http.NewRequest("GET", "/robots.txt", nil) proxy.ServeHTTP(rw, req) assert.Equal(t, 200, rw.Code) - assert.Equal(t, "User-agent: *\nDisallow: /", rw.Body.String()) + assert.Equal(t, "User-agent: *\nDisallow: /\n", rw.Body.String()) } func TestIsValidRedirect(t *testing.T) { diff --git a/pkg/app/pagewriter/pagewriter.go b/pkg/app/pagewriter/pagewriter.go index 3b7104b2..5991e625 100644 --- a/pkg/app/pagewriter/pagewriter.go +++ b/pkg/app/pagewriter/pagewriter.go @@ -13,12 +13,14 @@ type Writer interface { WriteSignInPage(rw http.ResponseWriter, req *http.Request, redirectURL string) WriteErrorPage(rw http.ResponseWriter, opts ErrorPageOpts) ProxyErrorHandler(rw http.ResponseWriter, req *http.Request, proxyErr error) + WriteRobotsTxt(rw http.ResponseWriter, req *http.Request) } // pageWriter implements the Writer interface type pageWriter struct { *errorPageWriter *signInPageWriter + *staticPageWriter } // Opts contains all options required to configure the template @@ -88,8 +90,14 @@ func NewWriter(opts Opts) (Writer, error) { logoData: logoData, } + staticPages, err := newStaticPageWriter(opts.TemplatesPath, errorPage) + if err != nil { + return nil, fmt.Errorf("error loading static page writer: %v", err) + } + return &pageWriter{ errorPageWriter: errorPage, signInPageWriter: signInPage, + staticPageWriter: staticPages, }, nil } diff --git a/pkg/app/pagewriter/robots.txt b/pkg/app/pagewriter/robots.txt new file mode 100644 index 00000000..1f53798b --- /dev/null +++ b/pkg/app/pagewriter/robots.txt @@ -0,0 +1,2 @@ +User-agent: * +Disallow: / diff --git a/pkg/app/pagewriter/static_pages.go b/pkg/app/pagewriter/static_pages.go new file mode 100644 index 00000000..3e1bdcdc --- /dev/null +++ b/pkg/app/pagewriter/static_pages.go @@ -0,0 +1,118 @@ +package pagewriter + +import ( + // Import embed to allow importing default page templates + _ "embed" + "fmt" + "net/http" + "os" + "path/filepath" + + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" +) + +const ( + robotsTxtName = "robots.txt" +) + +//go:embed robots.txt +var defaultRobotsTxt []byte + +// staticPageWriter is used to write static pages. +type staticPageWriter struct { + pageGetter *pageGetter + errorPageWriter *errorPageWriter +} + +// WriteRobotsTxt writes the robots.txt content to the response writer. +func (s *staticPageWriter) WriteRobotsTxt(rw http.ResponseWriter, req *http.Request) { + s.writePage(rw, req, robotsTxtName) +} + +// writePage writes the content of the page to the response writer. +func (s *staticPageWriter) writePage(rw http.ResponseWriter, req *http.Request, pageName string) { + _, err := rw.Write(s.pageGetter.getPage(pageName)) + if err != nil { + logger.Printf("Error writing %q: %v", pageName, err) + scope := middlewareapi.GetRequestScope(req) + s.errorPageWriter.WriteErrorPage(rw, ErrorPageOpts{ + Status: http.StatusInternalServerError, + RequestID: scope.RequestID, + AppError: err.Error(), + }) + return + } +} + +func newStaticPageWriter(customDir string, errorWriter *errorPageWriter) (*staticPageWriter, error) { + pageGetter, err := loadStaticPages(customDir) + if err != nil { + return nil, fmt.Errorf("could not load static pages: %v", err) + } + + return &staticPageWriter{ + pageGetter: pageGetter, + errorPageWriter: errorWriter, + }, nil +} + +// loadStaticPages loads static page content from the custom directory provided. +// If any file is not provided in the custom directory, the default will be used +// instead. +// Statis files include: +// - robots.txt +func loadStaticPages(customDir string) (*pageGetter, error) { + pages := newPageGetter(customDir) + + if err := pages.addPage(robotsTxtName, defaultRobotsTxt); err != nil { + return nil, fmt.Errorf("could not add robots.txt: %v", err) + } + + return pages, nil +} + +// pageGetter is used to load and read page content for static pages. +type pageGetter struct { + pages map[string][]byte + dir string +} + +// newPageGetter creates a new page getter for the custom directory. +func newPageGetter(customDir string) *pageGetter { + return &pageGetter{ + pages: make(map[string][]byte), + dir: customDir, + } +} + +// addPage loads a new page into the pageGetter. +// If the given file name does not exist in the custom directory, the default +// content will be used instead. +func (p *pageGetter) addPage(fileName string, defaultContent []byte) error { + filePath := filepath.Join(p.dir, fileName) + if p.dir != "" && isFile(filePath) { + content, err := os.ReadFile(filePath) + if err != nil { + return fmt.Errorf("could not read file: %v", err) + } + + p.pages[fileName] = content + return nil + } + + // No custom content defined, use the default. + p.pages[fileName] = defaultContent + return nil +} + +// getPage returns the page content for a given page. +func (p *pageGetter) getPage(name string) []byte { + content, ok := p.pages[name] + if !ok { + // If the page isn't registered, something went wrong and there is a bug. + // Tests should make sure this code path is never hit. + panic(fmt.Sprintf("Static page %q not found", name)) + } + return content +} diff --git a/pkg/app/pagewriter/static_pages_test.go b/pkg/app/pagewriter/static_pages_test.go new file mode 100644 index 00000000..f52451ba --- /dev/null +++ b/pkg/app/pagewriter/static_pages_test.go @@ -0,0 +1,153 @@ +package pagewriter + +import ( + "errors" + "html/template" + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + + middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Static Pages", func() { + var customDir string + const customRobots = "User-agent: *\nAllow: /\n" + var errorPage *errorPageWriter + var request *http.Request + + BeforeEach(func() { + errorTmpl, err := template.New("").Parse("{{.Title}}") + Expect(err).ToNot(HaveOccurred()) + errorPage = &errorPageWriter{ + template: errorTmpl, + } + + customDir, err = ioutil.TempDir("", "oauth2-proxy-static-pages-test") + Expect(err).ToNot(HaveOccurred()) + + robotsTxtFile := filepath.Join(customDir, robotsTxtName) + Expect(ioutil.WriteFile(robotsTxtFile, []byte(customRobots), 0400)).To(Succeed()) + + request = httptest.NewRequest("", "http://127.0.0.1/", nil) + request = middlewareapi.AddRequestScope(request, &middlewareapi.RequestScope{ + RequestID: testRequestID, + }) + }) + + AfterEach(func() { + Expect(os.RemoveAll(customDir)).To(Succeed()) + }) + + Context("Static Page Writer", func() { + Context("With custom content", func() { + var pageWriter *staticPageWriter + + BeforeEach(func() { + var err error + pageWriter, err = newStaticPageWriter(customDir, errorPage) + Expect(err).ToNot(HaveOccurred()) + }) + + Context("WriterRobotsTxt", func() { + It("Should write the custom robots txt", func() { + recorder := httptest.NewRecorder() + pageWriter.WriteRobotsTxt(recorder, request) + + body, err := ioutil.ReadAll(recorder.Result().Body) + Expect(err).ToNot(HaveOccurred()) + Expect(string(body)).To(Equal(customRobots)) + + Expect(recorder.Result().StatusCode).To(Equal(http.StatusOK)) + }) + }) + }) + + Context("Without custom content", func() { + var pageWriter *staticPageWriter + + BeforeEach(func() { + var err error + pageWriter, err = newStaticPageWriter("", errorPage) + Expect(err).ToNot(HaveOccurred()) + }) + + Context("WriterRobotsTxt", func() { + It("Should write the custom robots txt", func() { + recorder := httptest.NewRecorder() + pageWriter.WriteRobotsTxt(recorder, request) + + body, err := ioutil.ReadAll(recorder.Result().Body) + Expect(err).ToNot(HaveOccurred()) + Expect(string(body)).To(Equal(string(defaultRobotsTxt))) + + Expect(recorder.Result().StatusCode).To(Equal(http.StatusOK)) + }) + + It("Should serve an error if it cannot write the page", func() { + recorder := &testBadResponseWriter{ + ResponseRecorder: httptest.NewRecorder(), + } + pageWriter.WriteRobotsTxt(recorder, request) + + body, err := ioutil.ReadAll(recorder.Result().Body) + Expect(err).ToNot(HaveOccurred()) + Expect(string(body)).To(Equal(string("Internal Server Error"))) + + Expect(recorder.Result().StatusCode).To(Equal(http.StatusInternalServerError)) + }) + }) + }) + }) + + Context("loadStaticPages", func() { + Context("With custom content", func() { + Context("And a custom robots txt", func() { + It("Loads the custom content", func() { + pages, err := loadStaticPages(customDir) + Expect(err).ToNot(HaveOccurred()) + Expect(pages.pages).To(HaveLen(1)) + Expect(pages.getPage(robotsTxtName)).To(BeEquivalentTo(customRobots)) + }) + }) + + Context("And no custom robots txt", func() { + It("returns the default content", func() { + robotsTxtFile := filepath.Join(customDir, robotsTxtName) + Expect(os.Remove(robotsTxtFile)).To(Succeed()) + + pages, err := loadStaticPages(customDir) + Expect(err).ToNot(HaveOccurred()) + Expect(pages.pages).To(HaveLen(1)) + Expect(pages.getPage(robotsTxtName)).To(BeEquivalentTo(defaultRobotsTxt)) + }) + }) + }) + + Context("Without custom content", func() { + It("Loads the default content", func() { + pages, err := loadStaticPages("") + Expect(err).ToNot(HaveOccurred()) + Expect(pages.pages).To(HaveLen(1)) + Expect(pages.getPage(robotsTxtName)).To(BeEquivalentTo(defaultRobotsTxt)) + }) + }) + }) +}) + +type testBadResponseWriter struct { + *httptest.ResponseRecorder + firstWriteCalled bool +} + +func (b *testBadResponseWriter) Write(buf []byte) (int, error) { + if !b.firstWriteCalled { + b.firstWriteCalled = true + return 0, errors.New("write closed") + } + return b.ResponseRecorder.Write(buf) +}