Merge pull request #7 from pusher/migration
Migration from Bitly to Pusher
This commit is contained in:
		
						commit
						e1f45dd941
					
				|  | @ -0,0 +1,3 @@ | ||||||
|  | # Default owner should be a Pusher cloud-team member unless overridden by later | ||||||
|  | # rules in this file | ||||||
|  | * @pusher/cloud-team | ||||||
|  | @ -0,0 +1,37 @@ | ||||||
|  | <!--- Provide a general summary of the issue in the Title above --> | ||||||
|  | 
 | ||||||
|  | ## Expected Behavior | ||||||
|  | 
 | ||||||
|  | <!--- If you're describing a bug, tell us what should happen --> | ||||||
|  | <!--- If you're suggesting a change/improvement, tell us how it should work --> | ||||||
|  | 
 | ||||||
|  | ## Current Behavior | ||||||
|  | 
 | ||||||
|  | <!--- If describing a bug, tell us what happens instead of the expected behavior --> | ||||||
|  | <!--- If suggesting a change/improvement, explain the difference from current behavior --> | ||||||
|  | 
 | ||||||
|  | ## Possible Solution | ||||||
|  | 
 | ||||||
|  | <!--- Not obligatory, but suggest a fix/reason for the bug, --> | ||||||
|  | <!--- or ideas how to implement the addition or change --> | ||||||
|  | 
 | ||||||
|  | ## Steps to Reproduce (for bugs) | ||||||
|  | 
 | ||||||
|  | <!--- Provide a link to a live example, or an unambiguous set of steps to --> | ||||||
|  | <!--- reproduce this bug. Include code to reproduce, if relevant --> | ||||||
|  | 
 | ||||||
|  | 1.  <!--- Step 1 ---> | ||||||
|  | 2.  <!--- Step 2 ---> | ||||||
|  | 3.  <!--- Step 3 ---> | ||||||
|  | 4.  <!--- Step 4 ---> | ||||||
|  | 
 | ||||||
|  | ## Context | ||||||
|  | 
 | ||||||
|  | <!--- How has this issue affected you? What are you trying to accomplish? --> | ||||||
|  | <!--- Providing context helps us come up with a solution that is most useful in the real world --> | ||||||
|  | 
 | ||||||
|  | ## Your Environment | ||||||
|  | 
 | ||||||
|  | <!--- Include as many relevant details about the environment you experienced the bug in --> | ||||||
|  | 
 | ||||||
|  | - Version used: | ||||||
|  | @ -0,0 +1,25 @@ | ||||||
|  | <!--- Provide a general summary of your changes in the Title above --> | ||||||
|  | 
 | ||||||
|  | ## Description | ||||||
|  | 
 | ||||||
|  | <!--- Describe your changes in detail --> | ||||||
|  | 
 | ||||||
|  | ## Motivation and Context | ||||||
|  | 
 | ||||||
|  | <!--- Why is this change required? What problem does it solve? --> | ||||||
|  | <!--- If it fixes an open issue, please link to the issue here. --> | ||||||
|  | 
 | ||||||
|  | ## How Has This Been Tested? | ||||||
|  | 
 | ||||||
|  | <!--- Please describe in detail how you tested your changes. --> | ||||||
|  | <!--- Include details of your testing environment, and the tests you ran to --> | ||||||
|  | <!--- see how your change affects other areas of the code, etc. --> | ||||||
|  | 
 | ||||||
|  | ## Checklist: | ||||||
|  | 
 | ||||||
|  | <!--- Go over all the following points, and put an `x` in all the boxes that apply. --> | ||||||
|  | <!--- If you're unsure about any of these, don't hesitate to ask. We're here to help! --> | ||||||
|  | 
 | ||||||
|  | - [ ] My change requires a change to the documentation or CHANGELOG. | ||||||
|  | - [ ] I have updated the documentation/CHANGELOG accordingly. | ||||||
|  | - [ ] I have created a feature (non-master) branch for my PR. | ||||||
|  | @ -3,7 +3,7 @@ vendor | ||||||
| dist | dist | ||||||
| .godeps | .godeps | ||||||
| *.exe | *.exe | ||||||
| 
 | .env | ||||||
| 
 | 
 | ||||||
| # Go.gitignore | # Go.gitignore | ||||||
| # Compiled Object files, Static and Dynamic libs (Shared Objects) | # Compiled Object files, Static and Dynamic libs (Shared Objects) | ||||||
|  |  | ||||||
							
								
								
									
										12
									
								
								.travis.yml
								
								
								
								
							
							
						
						
									
										12
									
								
								.travis.yml
								
								
								
								
							|  | @ -1,12 +1,16 @@ | ||||||
| language: go | language: go | ||||||
| go: | go: | ||||||
|   - 1.8.x |  | ||||||
|   - 1.9.x |   - 1.9.x | ||||||
| script: |   - 1.10.x | ||||||
|  | install: | ||||||
|  |   # Fetch dependencies | ||||||
|   - wget -O dep https://github.com/golang/dep/releases/download/v0.3.2/dep-linux-amd64 |   - wget -O dep https://github.com/golang/dep/releases/download/v0.3.2/dep-linux-amd64 | ||||||
|   - chmod +x dep |   - chmod +x dep | ||||||
|   - ./dep ensure |   - mv dep $GOPATH/bin/dep | ||||||
|   - ./test.sh | script: | ||||||
|  |   - ./configure | ||||||
|  |   # Run tests | ||||||
|  |   - make test | ||||||
| sudo: false | sudo: false | ||||||
| notifications: | notifications: | ||||||
|   email: false |   email: false | ||||||
|  |  | ||||||
|  | @ -0,0 +1,23 @@ | ||||||
|  | # Vx.x.x (Pre-release) | ||||||
|  | 
 | ||||||
|  | ## Changes since v2.2: | ||||||
|  | 
 | ||||||
|  | - Move automated build to debian base image | ||||||
|  | - Add Makefile | ||||||
|  |   - Update CI to run `make test` | ||||||
|  |   - Update Dockerfile to use `make clean oauth2_proxy` | ||||||
|  |   - Update `VERSION` parameter to be set by `ldflags` from Git Status | ||||||
|  |   - Remove lint and test scripts | ||||||
|  | - Remove Go v1.8.x from Travis CI testing | ||||||
|  | - Add CODEOWNERS file | ||||||
|  | - Add CONTRIBUTING guide | ||||||
|  | - Add Issue and Pull Request templates | ||||||
|  | - Add Dockerfile | ||||||
|  | - Fix fsnotify import | ||||||
|  | - Update README to reflect new repository ownership | ||||||
|  | - Update CI scripts to separate linting and testing | ||||||
|  |   - Now using `gometalinter` for linting | ||||||
|  | - Move Go import path from `github.com/bitly/oauth2_proxy` to `github.com/pusher/oauth2_proxy` | ||||||
|  | - Repository forked on 27/11/18 | ||||||
|  |   - README updated to include note that this repository is forked | ||||||
|  |   - CHANGLOG created to track changes to repository from original fork | ||||||
|  | @ -0,0 +1,22 @@ | ||||||
|  | # Contributing | ||||||
|  | 
 | ||||||
|  | To develop on this project, please fork the repo and clone into your `$GOPATH`. | ||||||
|  | 
 | ||||||
|  | Dependencies are **not** checked in so please download those separately. | ||||||
|  | Download the dependencies using [`dep`](https://github.com/golang/dep). | ||||||
|  | 
 | ||||||
|  | ```bash | ||||||
|  | cd $GOPATH/src/github.com # Create this directory if it doesn't exist | ||||||
|  | git clone git@github.com:<YOUR_FORK>/oauth2_proxy pusher/oauth2_proxy | ||||||
|  | make dep | ||||||
|  | ``` | ||||||
|  | 
 | ||||||
|  | ## Pull Requests and Issues | ||||||
|  | 
 | ||||||
|  | We track bugs and issues using Github. | ||||||
|  | 
 | ||||||
|  | If you find a bug, please open an Issue. | ||||||
|  | 
 | ||||||
|  | If you want to fix a bug, please fork, create a feature branch, fix the bug and | ||||||
|  | open a PR back to this repo. | ||||||
|  | Please mention the open bug issue number within your PR if applicable. | ||||||
|  | @ -0,0 +1,16 @@ | ||||||
|  | FROM golang:1.10 AS builder | ||||||
|  | WORKDIR /go/src/github.com/pusher/oauth2_proxy | ||||||
|  | COPY . . | ||||||
|  | 
 | ||||||
|  | # Fetch dependencies | ||||||
|  | RUN go get -u github.com/golang/dep/cmd/dep | ||||||
|  | RUN dep ensure --vendor-only | ||||||
|  | 
 | ||||||
|  | # Build image | ||||||
|  | RUN ./configure && make clean oauth2_proxy | ||||||
|  | 
 | ||||||
|  | # Copy binary to debian | ||||||
|  | FROM debian:stretch | ||||||
|  | COPY --from=builder /go/src/github.com/pusher/oauth2_proxy/oauth2_proxy /bin/oauth2_proxy | ||||||
|  | 
 | ||||||
|  | ENTRYPOINT ["/bin/oauth2_proxy"] | ||||||
|  | @ -2,118 +2,149 @@ | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| [[projects]] | [[projects]] | ||||||
|  |   digest = "1:b24249f5a5e6fbe1eddc94b25973172339ccabeadef4779274f3ed0167c18812" | ||||||
|   name = "cloud.google.com/go" |   name = "cloud.google.com/go" | ||||||
|   packages = ["compute/metadata"] |   packages = ["compute/metadata"] | ||||||
|  |   pruneopts = "" | ||||||
|   revision = "2d3a6656c17a60b0815b7e06ab0be04eacb6e613" |   revision = "2d3a6656c17a60b0815b7e06ab0be04eacb6e613" | ||||||
|   version = "v0.16.0" |   version = "v0.16.0" | ||||||
| 
 | 
 | ||||||
| [[projects]] | [[projects]] | ||||||
|  |   digest = "1:289dd4d7abfb3ad2b5f728fbe9b1d5c1bf7d265a3eb9ef92869af1f7baba4c7a" | ||||||
|   name = "github.com/BurntSushi/toml" |   name = "github.com/BurntSushi/toml" | ||||||
|   packages = ["."] |   packages = ["."] | ||||||
|  |   pruneopts = "" | ||||||
|   revision = "b26d9c308763d68093482582cea63d69be07a0f0" |   revision = "b26d9c308763d68093482582cea63d69be07a0f0" | ||||||
|   version = "v0.3.0" |   version = "v0.3.0" | ||||||
| 
 | 
 | ||||||
| [[projects]] | [[projects]] | ||||||
|  |   digest = "1:512883404c2a99156e410e9880e3bb35ecccc0c07c1159eb204b5f3ef3c431b3" | ||||||
|   name = "github.com/bitly/go-simplejson" |   name = "github.com/bitly/go-simplejson" | ||||||
|   packages = ["."] |   packages = ["."] | ||||||
|  |   pruneopts = "" | ||||||
|   revision = "aabad6e819789e569bd6aabf444c935aa9ba1e44" |   revision = "aabad6e819789e569bd6aabf444c935aa9ba1e44" | ||||||
|   version = "v0.5.0" |   version = "v0.5.0" | ||||||
| 
 | 
 | ||||||
| [[projects]] | [[projects]] | ||||||
|   branch = "v2" |   branch = "v2" | ||||||
|  |   digest = "1:e5a238f8fa890e529d7e493849bbae8988c9e70344e4630cc4f9a11b00516afb" | ||||||
|   name = "github.com/coreos/go-oidc" |   name = "github.com/coreos/go-oidc" | ||||||
|   packages = ["."] |   packages = ["."] | ||||||
|  |   pruneopts = "" | ||||||
|   revision = "77e7f2010a464ade7338597afe650dfcffbe2ca8" |   revision = "77e7f2010a464ade7338597afe650dfcffbe2ca8" | ||||||
| 
 | 
 | ||||||
| [[projects]] | [[projects]] | ||||||
|  |   digest = "1:56c130d885a4aacae1dd9c7b71cfe39912c7ebc1ff7d2b46083c8812996dc43b" | ||||||
|   name = "github.com/davecgh/go-spew" |   name = "github.com/davecgh/go-spew" | ||||||
|   packages = ["spew"] |   packages = ["spew"] | ||||||
|  |   pruneopts = "" | ||||||
|   revision = "346938d642f2ec3594ed81d874461961cd0faa76" |   revision = "346938d642f2ec3594ed81d874461961cd0faa76" | ||||||
|   version = "v1.1.0" |   version = "v1.1.0" | ||||||
| 
 | 
 | ||||||
| [[projects]] | [[projects]] | ||||||
|   branch = "master" |   branch = "master" | ||||||
|  |   digest = "1:3b760d3b93f994df8eb1d9ebfad17d3e9e37edcb7f7efaa15b427c0d7a64f4e4" | ||||||
|   name = "github.com/golang/protobuf" |   name = "github.com/golang/protobuf" | ||||||
|   packages = ["proto"] |   packages = ["proto"] | ||||||
|  |   pruneopts = "" | ||||||
|   revision = "1e59b77b52bf8e4b449a57e6f79f21226d571845" |   revision = "1e59b77b52bf8e4b449a57e6f79f21226d571845" | ||||||
| 
 | 
 | ||||||
| [[projects]] | [[projects]] | ||||||
|  |   digest = "1:af67386ca553c04c6222f7b5b2f17bc97a5dfb3b81b706882c7fd8c72c30cf8f" | ||||||
|   name = "github.com/mbland/hmacauth" |   name = "github.com/mbland/hmacauth" | ||||||
|   packages = ["."] |   packages = ["."] | ||||||
|  |   pruneopts = "" | ||||||
|   revision = "107c17adcc5eccc9935cd67d9bc2feaf5255d2cb" |   revision = "107c17adcc5eccc9935cd67d9bc2feaf5255d2cb" | ||||||
|   version = "1.0.2" |   version = "1.0.2" | ||||||
| 
 | 
 | ||||||
| [[projects]] | [[projects]] | ||||||
|   branch = "master" |   branch = "master" | ||||||
|  |   digest = "1:9408fb9c637c103010e5147469c232ce6b68edc840879cc730a2a15918e6cae8" | ||||||
|   name = "github.com/mreiferson/go-options" |   name = "github.com/mreiferson/go-options" | ||||||
|   packages = ["."] |   packages = ["."] | ||||||
|  |   pruneopts = "" | ||||||
|   revision = "77551d20752b54535462404ad9d877ebdb26e53d" |   revision = "77551d20752b54535462404ad9d877ebdb26e53d" | ||||||
| 
 | 
 | ||||||
| [[projects]] | [[projects]] | ||||||
|  |   digest = "1:256484dbbcd271f9ecebc6795b2df8cad4c458dd0f5fd82a8c2fa0c29f233411" | ||||||
|   name = "github.com/pmezard/go-difflib" |   name = "github.com/pmezard/go-difflib" | ||||||
|   packages = ["difflib"] |   packages = ["difflib"] | ||||||
|  |   pruneopts = "" | ||||||
|   revision = "792786c7400a136282c1664665ae0a8db921c6c2" |   revision = "792786c7400a136282c1664665ae0a8db921c6c2" | ||||||
|   version = "v1.0.0" |   version = "v1.0.0" | ||||||
| 
 | 
 | ||||||
| [[projects]] | [[projects]] | ||||||
|   branch = "master" |   branch = "master" | ||||||
|  |   digest = "1:386e12afcfd8964907c92dffd106860c0dedd71dbefae14397b77b724a13343b" | ||||||
|   name = "github.com/pquerna/cachecontrol" |   name = "github.com/pquerna/cachecontrol" | ||||||
|   packages = [ |   packages = [ | ||||||
|     ".", |     ".", | ||||||
|     "cacheobject" |     "cacheobject", | ||||||
|   ] |   ] | ||||||
|  |   pruneopts = "" | ||||||
|   revision = "0dec1b30a0215bb68605dfc568e8855066c9202d" |   revision = "0dec1b30a0215bb68605dfc568e8855066c9202d" | ||||||
| 
 | 
 | ||||||
| [[projects]] | [[projects]] | ||||||
|  |   digest = "1:3926a4ec9a4ff1a072458451aa2d9b98acd059a45b38f7335d31e06c3d6a0159" | ||||||
|   name = "github.com/stretchr/testify" |   name = "github.com/stretchr/testify" | ||||||
|   packages = ["assert"] |   packages = ["assert"] | ||||||
|  |   pruneopts = "" | ||||||
|   revision = "69483b4bd14f5845b5a1e55bca19e954e827f1d0" |   revision = "69483b4bd14f5845b5a1e55bca19e954e827f1d0" | ||||||
|   version = "v1.1.4" |   version = "v1.1.4" | ||||||
| 
 | 
 | ||||||
| [[projects]] | [[projects]] | ||||||
|   branch = "master" |   branch = "master" | ||||||
|  |   digest = "1:f6a006d27619a4d93bf9b66fe1999b8c8d1fa62bdc63af14f10fbe6fcaa2aa1a" | ||||||
|   name = "golang.org/x/crypto" |   name = "golang.org/x/crypto" | ||||||
|   packages = [ |   packages = [ | ||||||
|     "bcrypt", |     "bcrypt", | ||||||
|     "blowfish", |     "blowfish", | ||||||
|     "ed25519", |     "ed25519", | ||||||
|     "ed25519/internal/edwards25519" |     "ed25519/internal/edwards25519", | ||||||
|   ] |   ] | ||||||
|  |   pruneopts = "" | ||||||
|   revision = "9f005a07e0d31d45e6656d241bb5c0f2efd4bc94" |   revision = "9f005a07e0d31d45e6656d241bb5c0f2efd4bc94" | ||||||
| 
 | 
 | ||||||
| [[projects]] | [[projects]] | ||||||
|   branch = "master" |   branch = "master" | ||||||
|  |   digest = "1:130b1bec86c62e121967ee0c69d9c263dc2d3ffe6c7c9a82aca4071c4d068861" | ||||||
|   name = "golang.org/x/net" |   name = "golang.org/x/net" | ||||||
|   packages = [ |   packages = [ | ||||||
|     "context", |     "context", | ||||||
|     "context/ctxhttp" |     "context/ctxhttp", | ||||||
|   ] |   ] | ||||||
|  |   pruneopts = "" | ||||||
|   revision = "9dfe39835686865bff950a07b394c12a98ddc811" |   revision = "9dfe39835686865bff950a07b394c12a98ddc811" | ||||||
| 
 | 
 | ||||||
| [[projects]] | [[projects]] | ||||||
|   branch = "master" |   branch = "master" | ||||||
|  |   digest = "1:4a61176e8386727e4847b21a5a2625ce56b9c518bc543a28226503e701265db0" | ||||||
|   name = "golang.org/x/oauth2" |   name = "golang.org/x/oauth2" | ||||||
|   packages = [ |   packages = [ | ||||||
|     ".", |     ".", | ||||||
|     "google", |     "google", | ||||||
|     "internal", |     "internal", | ||||||
|     "jws", |     "jws", | ||||||
|     "jwt" |     "jwt", | ||||||
|   ] |   ] | ||||||
|  |   pruneopts = "" | ||||||
|   revision = "9ff8ebcc8e241d46f52ecc5bff0e5a2f2dbef402" |   revision = "9ff8ebcc8e241d46f52ecc5bff0e5a2f2dbef402" | ||||||
| 
 | 
 | ||||||
| [[projects]] | [[projects]] | ||||||
|   branch = "master" |   branch = "master" | ||||||
|  |   digest = "1:dc1fb726dbbe79c86369941eae1e3b431b8fc6f11dbd37f7899dc758a43cc3ed" | ||||||
|   name = "google.golang.org/api" |   name = "google.golang.org/api" | ||||||
|   packages = [ |   packages = [ | ||||||
|     "admin/directory/v1", |     "admin/directory/v1", | ||||||
|     "gensupport", |     "gensupport", | ||||||
|     "googleapi", |     "googleapi", | ||||||
|     "googleapi/internal/uritemplates" |     "googleapi/internal/uritemplates", | ||||||
|   ] |   ] | ||||||
|  |   pruneopts = "" | ||||||
|   revision = "8791354e7ab150705ede13637a18c1fcc16b62e8" |   revision = "8791354e7ab150705ede13637a18c1fcc16b62e8" | ||||||
| 
 | 
 | ||||||
| [[projects]] | [[projects]] | ||||||
|  |   digest = "1:934fb8966f303ede63aa405e2c8d7f0a427a05ea8df335dfdc1833dd4d40756f" | ||||||
|   name = "google.golang.org/appengine" |   name = "google.golang.org/appengine" | ||||||
|   packages = [ |   packages = [ | ||||||
|     ".", |     ".", | ||||||
|  | @ -125,30 +156,48 @@ | ||||||
|     "internal/modules", |     "internal/modules", | ||||||
|     "internal/remote_api", |     "internal/remote_api", | ||||||
|     "internal/urlfetch", |     "internal/urlfetch", | ||||||
|     "urlfetch" |     "urlfetch", | ||||||
|   ] |   ] | ||||||
|  |   pruneopts = "" | ||||||
|   revision = "150dc57a1b433e64154302bdc40b6bb8aefa313a" |   revision = "150dc57a1b433e64154302bdc40b6bb8aefa313a" | ||||||
|   version = "v1.0.0" |   version = "v1.0.0" | ||||||
| 
 | 
 | ||||||
| [[projects]] | [[projects]] | ||||||
|   name = "gopkg.in/fsnotify.v1" |   digest = "1:cb5b2a45a3dd41c01ff779c54ae4c8aab0271d6d3b3f734c8a8bd2c890299ef2" | ||||||
|  |   name = "gopkg.in/fsnotify/fsnotify.v1" | ||||||
|   packages = ["."] |   packages = ["."] | ||||||
|  |   pruneopts = "" | ||||||
|   revision = "836bfd95fecc0f1511dd66bdbf2b5b61ab8b00b6" |   revision = "836bfd95fecc0f1511dd66bdbf2b5b61ab8b00b6" | ||||||
|   version = "v1.2.11" |   version = "v1.2.11" | ||||||
| 
 | 
 | ||||||
| [[projects]] | [[projects]] | ||||||
|  |   digest = "1:be4ed0a2b15944dd777a663681a39260ed05f9c4e213017ed2e2255622c8820c" | ||||||
|   name = "gopkg.in/square/go-jose.v2" |   name = "gopkg.in/square/go-jose.v2" | ||||||
|   packages = [ |   packages = [ | ||||||
|     ".", |     ".", | ||||||
|     "cipher", |     "cipher", | ||||||
|     "json" |     "json", | ||||||
|   ] |   ] | ||||||
|  |   pruneopts = "" | ||||||
|   revision = "f8f38de21b4dcd69d0413faf231983f5fd6634b1" |   revision = "f8f38de21b4dcd69d0413faf231983f5fd6634b1" | ||||||
|   version = "v2.1.3" |   version = "v2.1.3" | ||||||
| 
 | 
 | ||||||
| [solve-meta] | [solve-meta] | ||||||
|   analyzer-name = "dep" |   analyzer-name = "dep" | ||||||
|   analyzer-version = 1 |   analyzer-version = 1 | ||||||
|   inputs-digest = "b502c41a61115d14d6379be26b0300f65d173bdad852f0170d387ebf2d7ec173" |   input-imports = [ | ||||||
|  |     "github.com/BurntSushi/toml", | ||||||
|  |     "github.com/bitly/go-simplejson", | ||||||
|  |     "github.com/coreos/go-oidc", | ||||||
|  |     "github.com/mbland/hmacauth", | ||||||
|  |     "github.com/mreiferson/go-options", | ||||||
|  |     "github.com/stretchr/testify/assert", | ||||||
|  |     "golang.org/x/crypto/bcrypt", | ||||||
|  |     "golang.org/x/oauth2", | ||||||
|  |     "golang.org/x/oauth2/google", | ||||||
|  |     "google.golang.org/api/admin/directory/v1", | ||||||
|  |     "google.golang.org/api/googleapi", | ||||||
|  |     "gopkg.in/fsnotify/fsnotify.v1", | ||||||
|  |   ] | ||||||
|   solver-name = "gps-cdcl" |   solver-name = "gps-cdcl" | ||||||
|   solver-version = 1 |   solver-version = 1 | ||||||
|  |  | ||||||
|  | @ -3,10 +3,6 @@ | ||||||
| # for detailed Gopkg.toml documentation. | # for detailed Gopkg.toml documentation. | ||||||
| # | # | ||||||
| 
 | 
 | ||||||
| [[constraint]] |  | ||||||
|   name = "github.com/18F/hmacauth" |  | ||||||
|   version = "~1.0.1" |  | ||||||
| 
 |  | ||||||
| [[constraint]] | [[constraint]] | ||||||
|   name = "github.com/BurntSushi/toml" |   name = "github.com/BurntSushi/toml" | ||||||
|   version = "~0.3.0" |   version = "~0.3.0" | ||||||
|  | @ -36,7 +32,7 @@ | ||||||
|   name = "google.golang.org/api" |   name = "google.golang.org/api" | ||||||
| 
 | 
 | ||||||
| [[constraint]] | [[constraint]] | ||||||
|   name = "gopkg.in/fsnotify.v1" |   name = "gopkg.in/fsnotify/fsnotify.v1" | ||||||
|   version = "~1.2.0" |   version = "~1.2.0" | ||||||
| 
 | 
 | ||||||
| [[constraint]] | [[constraint]] | ||||||
|  |  | ||||||
|  | @ -0,0 +1,56 @@ | ||||||
|  | include .env | ||||||
|  | BINARY := oauth2_proxy | ||||||
|  | VERSION := $(shell git describe --always --long --dirty --tags 2>/dev/null || echo "undefined") | ||||||
|  | .NOTPARALLEL: | ||||||
|  | 
 | ||||||
|  | .PHONY: all | ||||||
|  | all: dep lint $(BINARY) | ||||||
|  | 
 | ||||||
|  | .PHONY: clean | ||||||
|  | clean: | ||||||
|  | 	rm -rf release | ||||||
|  | 	rm -f $(BINARY) | ||||||
|  | 
 | ||||||
|  | .PHONY: distclean | ||||||
|  | distclean: clean | ||||||
|  | 	rm -rf vendor | ||||||
|  | 
 | ||||||
|  | BIN_DIR := $(GOPATH)/bin | ||||||
|  | GOMETALINTER := $(BIN_DIR)/gometalinter | ||||||
|  | 
 | ||||||
|  | $(GOMETALINTER): | ||||||
|  | 	$(GO) get -u github.com/alecthomas/gometalinter | ||||||
|  | 	gometalinter --install %> /dev/null | ||||||
|  | 
 | ||||||
|  | .PHONY: lint | ||||||
|  | lint: $(GOMETALINTER) | ||||||
|  | 	$(GOMETALINTER) --vendor --disable-all \
 | ||||||
|  | 		--enable=vet \
 | ||||||
|  | 		--enable=vetshadow \
 | ||||||
|  | 		--enable=golint \
 | ||||||
|  | 		--enable=ineffassign \
 | ||||||
|  | 		--enable=goconst \
 | ||||||
|  | 		--enable=deadcode \
 | ||||||
|  | 		--enable=gofmt \
 | ||||||
|  | 		--enable=goimports \
 | ||||||
|  | 		--tests ./... | ||||||
|  | 
 | ||||||
|  | .PHONY: dep | ||||||
|  | dep: | ||||||
|  | 	$(DEP) ensure --vendor-only | ||||||
|  | 
 | ||||||
|  | .PHONY: build | ||||||
|  | build: clean $(BINARY) | ||||||
|  | 
 | ||||||
|  | $(BINARY): | ||||||
|  | 	$(GO) build -ldflags="-X main.VERSION=${VERSION}" -o $(BINARY) github.com/pusher/oauth2_proxy | ||||||
|  | 
 | ||||||
|  | .PHONY: test | ||||||
|  | test: dep lint | ||||||
|  | 	$(GO) test -v -race $(go list ./... | grep -v /vendor/) | ||||||
|  | 
 | ||||||
|  | .PHONY: release | ||||||
|  | release: lint test | ||||||
|  | 	mkdir release | ||||||
|  | 	GOOS=darwin GOARCH=amd64 go build -ldflags="-X main.VERSION=${VERSION}" -o release/$(BINARY)-darwin-amd64 github.com/pusher/oauth2_proxy | ||||||
|  | 	GOOS=linux GOARCH=amd64 go build -ldflags="-X main.VERSION=${VERSION}" -o release/$(BINARY)-linux-amd64 github.com/pusher/oauth2_proxy | ||||||
							
								
								
									
										146
									
								
								README.md
								
								
								
								
							
							
						
						
									
										146
									
								
								README.md
								
								
								
								
							|  | @ -1,11 +1,13 @@ | ||||||
| oauth2_proxy | # oauth2_proxy | ||||||
| ================= |  | ||||||
| 
 | 
 | ||||||
| A reverse proxy and static file server that provides authentication using Providers (Google, GitHub, and others) | A reverse proxy and static file server that provides authentication using Providers (Google, GitHub, and others) | ||||||
| to validate accounts by email, domain or group. | to validate accounts by email, domain or group. | ||||||
| 
 | 
 | ||||||
| [](http://travis-ci.org/bitly/oauth2_proxy) | **Note:** This repository was forked from [bitly/OAuth2_Proxy](https://github.com/bitly/oauth2_proxy) on 27/11/2018. | ||||||
|  | Versions v3.0.0 and up are from this fork and will have diverged from any changes in the original fork. | ||||||
|  | A list of changes can be seen in the [CHANGELOG](CHANGELOG.md). | ||||||
| 
 | 
 | ||||||
|  | [](http://travis-ci.org/pusher/oauth2_proxy) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
| 
 | 
 | ||||||
|  | @ -15,15 +17,24 @@ to validate accounts by email, domain or group. | ||||||
| 
 | 
 | ||||||
| ## Installation | ## Installation | ||||||
| 
 | 
 | ||||||
| 1. Download [Prebuilt Binary](https://github.com/bitly/oauth2_proxy/releases) (current release is `v2.2`) or build with `$ go get github.com/bitly/oauth2_proxy` which will put the binary in `$GOROOT/bin` | 1.  Choose how to deploy: | ||||||
|  | 
 | ||||||
|  |     a. Download [Prebuilt Binary](https://github.com/pusher/oauth2_proxy/releases) (current release is `v2.2`) | ||||||
|  | 
 | ||||||
|  |     b. Build with `$ go get github.com/pusher/oauth2_proxy` which will put the binary in `$GOROOT/bin` | ||||||
|  | 
 | ||||||
|  |     c. Using the prebuilt docker image [quay.io/pusher/oauth2_proxy](https://quay.io/pusher/oauth2_proxy) | ||||||
|  | 
 | ||||||
| Prebuilt binaries can be validated by extracting the file and verifying it against the `sha256sum.txt` checksum file provided for each release starting with version `v2.3`. | Prebuilt binaries can be validated by extracting the file and verifying it against the `sha256sum.txt` checksum file provided for each release starting with version `v2.3`. | ||||||
|  | 
 | ||||||
| ``` | ``` | ||||||
| sha256sum -c sha256sum.txt 2>&1 | grep OK | sha256sum -c sha256sum.txt 2>&1 | grep OK | ||||||
| oauth2_proxy-2.3.linux-amd64: OK | oauth2_proxy-2.3.linux-amd64: OK | ||||||
| ``` | ``` | ||||||
| 2. Select a Provider and Register an OAuth Application with a Provider | 
 | ||||||
| 3. Configure OAuth2 Proxy using config file, command line options, or environment variables | 2.  Select a Provider and Register an OAuth Application with a Provider | ||||||
| 4. Configure SSL or Deploy behind a SSL endpoint (example provided for Nginx) | 3.  Configure OAuth2 Proxy using config file, command line options, or environment variables | ||||||
|  | 4.  Configure SSL or Deploy behind a SSL endpoint (example provided for Nginx) | ||||||
| 
 | 
 | ||||||
| ## OAuth Provider Configuration | ## OAuth Provider Configuration | ||||||
| 
 | 
 | ||||||
|  | @ -31,12 +42,12 @@ You will need to register an OAuth application with a Provider (Google, GitHub o | ||||||
| 
 | 
 | ||||||
| Valid providers are : | Valid providers are : | ||||||
| 
 | 
 | ||||||
| * [Google](#google-auth-provider) *default* | - [Google](#google-auth-provider) _default_ | ||||||
| * [Azure](#azure-auth-provider) | - [Azure](#azure-auth-provider) | ||||||
| * [Facebook](#facebook-auth-provider) | - [Facebook](#facebook-auth-provider) | ||||||
| * [GitHub](#github-auth-provider) | - [GitHub](#github-auth-provider) | ||||||
| * [GitLab](#gitlab-auth-provider) | - [GitLab](#gitlab-auth-provider) | ||||||
| * [LinkedIn](#linkedin-auth-provider) | - [LinkedIn](#linkedin-auth-provider) | ||||||
| 
 | 
 | ||||||
| The provider can be selected using the `provider` configuration value. | The provider can be selected using the `provider` configuration value. | ||||||
| 
 | 
 | ||||||
|  | @ -44,61 +55,62 @@ The provider can be selected using the `provider` configuration value. | ||||||
| 
 | 
 | ||||||
| For Google, the registration steps are: | For Google, the registration steps are: | ||||||
| 
 | 
 | ||||||
| 1. Create a new project: https://console.developers.google.com/project | 1.  Create a new project: https://console.developers.google.com/project | ||||||
| 2. Choose the new project from the top right project dropdown (only if another project is selected) | 2.  Choose the new project from the top right project dropdown (only if another project is selected) | ||||||
| 3. In the project Dashboard center pane, choose **"API Manager"** | 3.  In the project Dashboard center pane, choose **"API Manager"** | ||||||
| 4. In the left Nav pane, choose **"Credentials"** | 4.  In the left Nav pane, choose **"Credentials"** | ||||||
| 5. In the center pane, choose **"OAuth consent screen"** tab. Fill in **"Product name shown to users"** and hit save. | 5.  In the center pane, choose **"OAuth consent screen"** tab. Fill in **"Product name shown to users"** and hit save. | ||||||
| 6. In the center pane, choose **"Credentials"** tab. | 6.  In the center pane, choose **"Credentials"** tab. | ||||||
|    * Open the **"New credentials"** drop down |     - Open the **"New credentials"** drop down | ||||||
|    * Choose **"OAuth client ID"** |     - Choose **"OAuth client ID"** | ||||||
|    * Choose **"Web application"** |     - Choose **"Web application"** | ||||||
|    * Application name is freeform, choose something appropriate |     - Application name is freeform, choose something appropriate | ||||||
|    * Authorized JavaScript origins is your domain ex: `https://internal.yourcompany.com` |     - Authorized JavaScript origins is your domain ex: `https://internal.yourcompany.com` | ||||||
|    * Authorized redirect URIs is the location of oauth2/callback ex: `https://internal.yourcompany.com/oauth2/callback` |     - Authorized redirect URIs is the location of oauth2/callback ex: `https://internal.yourcompany.com/oauth2/callback` | ||||||
|    * Choose **"Create"** |     - Choose **"Create"** | ||||||
| 4. Take note of the **Client ID** and **Client Secret** | 7.  Take note of the **Client ID** and **Client Secret** | ||||||
| 
 | 
 | ||||||
| It's recommended to refresh sessions on a short interval (1h) with `cookie-refresh` setting which validates that the account is still authorized. | It's recommended to refresh sessions on a short interval (1h) with `cookie-refresh` setting which validates that the account is still authorized. | ||||||
| 
 | 
 | ||||||
| #### Restrict auth to specific Google groups on your domain. (optional) | #### Restrict auth to specific Google groups on your domain. (optional) | ||||||
| 
 | 
 | ||||||
| 1. Create a service account: https://developers.google.com/identity/protocols/OAuth2ServiceAccount and make sure to download the json file. | 1.  Create a service account: https://developers.google.com/identity/protocols/OAuth2ServiceAccount and make sure to download the json file. | ||||||
| 2. Make note of the Client ID for a future step. | 2.  Make note of the Client ID for a future step. | ||||||
| 3. Under "APIs & Auth", choose APIs. | 3.  Under "APIs & Auth", choose APIs. | ||||||
| 4. Click on Admin SDK and then Enable API. | 4.  Click on Admin SDK and then Enable API. | ||||||
| 5. Follow the steps on https://developers.google.com/admin-sdk/directory/v1/guides/delegation#delegate_domain-wide_authority_to_your_service_account and give the client id from step 2 the following oauth scopes: | 5.  Follow the steps on https://developers.google.com/admin-sdk/directory/v1/guides/delegation#delegate_domain-wide_authority_to_your_service_account and give the client id from step 2 the following oauth scopes: | ||||||
|  | 
 | ||||||
| ``` | ``` | ||||||
| https://www.googleapis.com/auth/admin.directory.group.readonly | https://www.googleapis.com/auth/admin.directory.group.readonly | ||||||
| https://www.googleapis.com/auth/admin.directory.user.readonly | https://www.googleapis.com/auth/admin.directory.user.readonly | ||||||
| ``` | ``` | ||||||
| 6. Follow the steps on https://support.google.com/a/answer/60757 to enable Admin API access. | 
 | ||||||
| 7. Create or choose an existing administrative email address on the Gmail domain to assign to the ```google-admin-email``` flag. This email will be impersonated by this client to make calls to the Admin SDK. See the note on the link from step 5 for the reason why. | 6.  Follow the steps on https://support.google.com/a/answer/60757 to enable Admin API access. | ||||||
| 8. Create or choose an existing email group and set that email to the ```google-group``` flag. You can pass multiple instances of this flag with different groups | 7.  Create or choose an existing administrative email address on the Gmail domain to assign to the `google-admin-email` flag. This email will be impersonated by this client to make calls to the Admin SDK. See the note on the link from step 5 for the reason why. | ||||||
| and the user will be checked against all the provided groups. | 8.  Create or choose an existing email group and set that email to the `google-group` flag. You can pass multiple instances of this flag with different groups | ||||||
| 9. Lock down the permissions on the json file downloaded from step 1 so only oauth2_proxy is able to read the file and set the path to the file in the ```google-service-account-json``` flag. |     and the user will be checked against all the provided groups. | ||||||
|  | 9.  Lock down the permissions on the json file downloaded from step 1 so only oauth2_proxy is able to read the file and set the path to the file in the `google-service-account-json` flag. | ||||||
| 10. Restart oauth2_proxy. | 10. Restart oauth2_proxy. | ||||||
| 
 | 
 | ||||||
| Note: The user is checked against the group members list on initial authentication and every time the token is refreshed ( about once an hour ). | Note: The user is checked against the group members list on initial authentication and every time the token is refreshed ( about once an hour ). | ||||||
| 
 | 
 | ||||||
| ### Azure Auth Provider | ### Azure Auth Provider | ||||||
| 
 | 
 | ||||||
| 1. [Add an application](https://azure.microsoft.com/en-us/documentation/articles/active-directory-integrating-applications/) to your Azure Active Directory tenant. | 1.  [Add an application](https://azure.microsoft.com/en-us/documentation/articles/active-directory-integrating-applications/) to your Azure Active Directory tenant. | ||||||
| 2. On the App properties page provide the correct Sign-On URL ie `https://internal.yourcompany.com/oauth2/callback` | 2.  On the App properties page provide the correct Sign-On URL ie `https://internal.yourcompany.com/oauth2/callback` | ||||||
| 3. If applicable take note of your `TenantID` and provide it via the `--azure-tenant=<YOUR TENANT ID>` commandline option. Default the `common` tenant is used. | 3.  If applicable take note of your `TenantID` and provide it via the `--azure-tenant=<YOUR TENANT ID>` commandline option. Default the `common` tenant is used. | ||||||
| 
 | 
 | ||||||
| The Azure AD auth provider uses `openid` as it default scope. It uses `https://graph.windows.net` as a default protected resource. It call to `https://graph.windows.net/me` to get the email address of the user that logs in. | The Azure AD auth provider uses `openid` as it default scope. It uses `https://graph.windows.net` as a default protected resource. It call to `https://graph.windows.net/me` to get the email address of the user that logs in. | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| ### Facebook Auth Provider | ### Facebook Auth Provider | ||||||
| 
 | 
 | ||||||
| 1. Create a new FB App from <https://developers.facebook.com/> | 1.  Create a new FB App from <https://developers.facebook.com/> | ||||||
| 2. Under FB Login, set your Valid OAuth redirect URIs to `https://internal.yourcompany.com/oauth2/callback` | 2.  Under FB Login, set your Valid OAuth redirect URIs to `https://internal.yourcompany.com/oauth2/callback` | ||||||
| 
 | 
 | ||||||
| ### GitHub Auth Provider | ### GitHub Auth Provider | ||||||
| 
 | 
 | ||||||
| 1. Create a new project: https://github.com/settings/developers | 1.  Create a new project: https://github.com/settings/developers | ||||||
| 2. Under `Authorization callback URL` enter the correct url ie `https://internal.yourcompany.com/oauth2/callback` | 2.  Under `Authorization callback URL` enter the correct url ie `https://internal.yourcompany.com/oauth2/callback` | ||||||
| 
 | 
 | ||||||
| The GitHub auth provider supports two additional parameters to restrict authentication to Organization or Team level access. Restricting by org and team is normally accompanied with `--email-domain=*` | The GitHub auth provider supports two additional parameters to restrict authentication to Organization or Team level access. Restricting by org and team is normally accompanied with `--email-domain=*` | ||||||
| 
 | 
 | ||||||
|  | @ -121,17 +133,16 @@ If you are using self-hosted GitLab, make sure you set the following to the appr | ||||||
|     -redeem-url="<your gitlab url>/oauth/token" |     -redeem-url="<your gitlab url>/oauth/token" | ||||||
|     -validate-url="<your gitlab url>/api/v4/user" |     -validate-url="<your gitlab url>/api/v4/user" | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| ### LinkedIn Auth Provider | ### LinkedIn Auth Provider | ||||||
| 
 | 
 | ||||||
| For LinkedIn, the registration steps are: | For LinkedIn, the registration steps are: | ||||||
| 
 | 
 | ||||||
| 1. Create a new project: https://www.linkedin.com/secure/developer | 1.  Create a new project: https://www.linkedin.com/secure/developer | ||||||
| 2. In the OAuth User Agreement section: | 2.  In the OAuth User Agreement section: | ||||||
|    * In default scope, select r_basicprofile and r_emailaddress. |     - In default scope, select r_basicprofile and r_emailaddress. | ||||||
|    * In "OAuth 2.0 Redirect URLs", enter `https://internal.yourcompany.com/oauth2/callback` |     - In "OAuth 2.0 Redirect URLs", enter `https://internal.yourcompany.com/oauth2/callback` | ||||||
| 3. Fill in the remaining required fields and Save. | 3.  Fill in the remaining required fields and Save. | ||||||
| 4. Take note of the **Consumer Key / API Key** and **Consumer Secret / Secret Key** | 4.  Take note of the **Consumer Key / API Key** and **Consumer Secret / Secret Key** | ||||||
| 
 | 
 | ||||||
| ### Microsoft Azure AD Provider | ### Microsoft Azure AD Provider | ||||||
| 
 | 
 | ||||||
|  | @ -143,9 +154,9 @@ Take note of your `TenantId` if applicable for your situation. The `TenantId` ca | ||||||
| 
 | 
 | ||||||
| OpenID Connect is a spec for OAUTH 2.0 + identity that is implemented by many major providers and several open source projects. This provider was originally built against CoreOS Dex and we will use it as an example. | OpenID Connect is a spec for OAUTH 2.0 + identity that is implemented by many major providers and several open source projects. This provider was originally built against CoreOS Dex and we will use it as an example. | ||||||
| 
 | 
 | ||||||
| 1. Launch a Dex instance using the [getting started guide](https://github.com/coreos/dex/blob/master/Documentation/getting-started.md). | 1.  Launch a Dex instance using the [getting started guide](https://github.com/coreos/dex/blob/master/Documentation/getting-started.md). | ||||||
| 2. Setup oauth2_proxy with the correct provider and using the default ports and callbacks. | 2.  Setup oauth2_proxy with the correct provider and using the default ports and callbacks. | ||||||
| 3. Login with the fixture use in the dex guide and run the oauth2_proxy with the following args: | 3.  Login with the fixture use in the dex guide and run the oauth2_proxy with the following args: | ||||||
| 
 | 
 | ||||||
|     -provider oidc |     -provider oidc | ||||||
|     -client-id oauth2_proxy |     -client-id oauth2_proxy | ||||||
|  | @ -253,7 +264,7 @@ The following environment variables can be used in place of the corresponding co | ||||||
| 
 | 
 | ||||||
| There are two recommended configurations. | There are two recommended configurations. | ||||||
| 
 | 
 | ||||||
| 1) Configure SSL Termination with OAuth2 Proxy by providing a `--tls-cert=/path/to/cert.pem` and `--tls-key=/path/to/cert.key`. | 1.  Configure SSL Termination with OAuth2 Proxy by providing a `--tls-cert=/path/to/cert.pem` and `--tls-key=/path/to/cert.key`. | ||||||
| 
 | 
 | ||||||
| The command line to run `oauth2_proxy` in this configuration would look like this: | The command line to run `oauth2_proxy` in this configuration would look like this: | ||||||
| 
 | 
 | ||||||
|  | @ -270,8 +281,7 @@ The command line to run `oauth2_proxy` in this configuration would look like thi | ||||||
|    --client-secret=... |    --client-secret=... | ||||||
| ``` | ``` | ||||||
| 
 | 
 | ||||||
| 
 | 2.  Configure SSL Termination with [Nginx](http://nginx.org/) (example config below), Amazon ELB, Google Cloud Platform Load Balancing, or .... | ||||||
| 2) Configure SSL Termination with [Nginx](http://nginx.org/) (example config below), Amazon ELB, Google Cloud Platform Load Balancing, or .... |  | ||||||
| 
 | 
 | ||||||
| Because `oauth2_proxy` listens on `127.0.0.1:4180` by default, to listen on all interfaces (needed when using an | Because `oauth2_proxy` listens on `127.0.0.1:4180` by default, to listen on all interfaces (needed when using an | ||||||
| external load balancer like Amazon ELB or Google Platform Load Balancing) use `--http-address="0.0.0.0:4180"` or | external load balancer like Amazon ELB or Google Platform Load Balancing) use `--http-address="0.0.0.0:4180"` or | ||||||
|  | @ -321,12 +331,12 @@ The command line to run `oauth2_proxy` in this configuration would look like thi | ||||||
| 
 | 
 | ||||||
| OAuth2 Proxy responds directly to the following endpoints. All other endpoints will be proxied upstream when authenticated. The `/oauth2` prefix can be changed with the `--proxy-prefix` config variable. | OAuth2 Proxy responds directly to the following endpoints. All other endpoints will be proxied upstream when authenticated. The `/oauth2` prefix can be changed with the `--proxy-prefix` config variable. | ||||||
| 
 | 
 | ||||||
| * /robots.txt - returns a 200 OK response that disallows all User-agents from all paths; see [robotstxt.org](http://www.robotstxt.org/) for more info | - /robots.txt - returns a 200 OK response that disallows all User-agents from all paths; see [robotstxt.org](http://www.robotstxt.org/) for more info | ||||||
| * /ping - returns an 200 OK response | - /ping - returns an 200 OK response | ||||||
| * /oauth2/sign_in - the login page, which also doubles as a sign out page (it clears cookies) | - /oauth2/sign_in - the login page, which also doubles as a sign out page (it clears cookies) | ||||||
| * /oauth2/start - a URL that will redirect to start the OAuth cycle | - /oauth2/start - a URL that will redirect to start the OAuth cycle | ||||||
| * /oauth2/callback - the URL used at the end of the OAuth cycle. The oauth app will be configured with this as the callback url. | - /oauth2/callback - the URL used at the end of the OAuth cycle. The oauth app will be configured with this as the callback url. | ||||||
| * /oauth2/auth - only returns a 202 Accepted response or a 401 Unauthorized response; for use with the [Nginx `auth_request` directive](#nginx-auth-request) | - /oauth2/auth - only returns a 202 Accepted response or a 401 Unauthorized response; for use with the [Nginx `auth_request` directive](#nginx-auth-request) | ||||||
| 
 | 
 | ||||||
| ## Request signatures | ## Request signatures | ||||||
| 
 | 
 | ||||||
|  | @ -341,9 +351,9 @@ in `oauthproxy.go`](./oauthproxy.go). | ||||||
| For more information about HMAC request signature validation, read the | For more information about HMAC request signature validation, read the | ||||||
| following: | following: | ||||||
| 
 | 
 | ||||||
| * [Amazon Web Services: Signing and Authenticating REST | - [Amazon Web Services: Signing and Authenticating REST | ||||||
|   Requests](https://docs.aws.amazon.com/AmazonS3/latest/dev/RESTAuthentication.html) |   Requests](https://docs.aws.amazon.com/AmazonS3/latest/dev/RESTAuthentication.html) | ||||||
| * [rc3.org: Using HMAC to authenticate Web service | - [rc3.org: Using HMAC to authenticate Web service | ||||||
|   requests](http://rc3.org/2011/12/02/using-hmac-to-authenticate-web-service-requests/) |   requests](http://rc3.org/2011/12/02/using-hmac-to-authenticate-web-service-requests/) | ||||||
| 
 | 
 | ||||||
| ## Logging Format | ## Logging Format | ||||||
|  | @ -417,3 +427,7 @@ server { | ||||||
|   } |   } | ||||||
| } | } | ||||||
| ``` | ``` | ||||||
|  | 
 | ||||||
|  | ## Contributing | ||||||
|  | 
 | ||||||
|  | Please see our [Contributing](CONTRIBUTING.md) guidelines. | ||||||
|  |  | ||||||
|  | @ -10,6 +10,7 @@ import ( | ||||||
| 	"github.com/bitly/go-simplejson" | 	"github.com/bitly/go-simplejson" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // Request parses the request body into a simplejson.Json object
 | ||||||
| func Request(req *http.Request) (*simplejson.Json, error) { | func Request(req *http.Request) (*simplejson.Json, error) { | ||||||
| 	resp, err := http.DefaultClient.Do(req) | 	resp, err := http.DefaultClient.Do(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | @ -32,7 +33,8 @@ func Request(req *http.Request) (*simplejson.Json, error) { | ||||||
| 	return data, nil | 	return data, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func RequestJson(req *http.Request, v interface{}) error { | // RequestJSON parses the request body into the given interface
 | ||||||
|  | func RequestJSON(req *http.Request, v interface{}) error { | ||||||
| 	resp, err := http.DefaultClient.Do(req) | 	resp, err := http.DefaultClient.Do(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Printf("%s %s %s", req.Method, req.URL, err) | 		log.Printf("%s %s %s", req.Method, req.URL, err) | ||||||
|  | @ -50,6 +52,7 @@ func RequestJson(req *http.Request, v interface{}) error { | ||||||
| 	return json.Unmarshal(body, v) | 	return json.Unmarshal(body, v) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // RequestUnparsedResponse performs a GET and returns the raw response object
 | ||||||
| func RequestUnparsedResponse(url string, header http.Header) (resp *http.Response, err error) { | func RequestUnparsedResponse(url string, header http.Header) (resp *http.Response, err error) { | ||||||
| 	req, err := http.NewRequest("GET", url, nil) | 	req, err := http.NewRequest("GET", url, nil) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  |  | ||||||
|  | @ -1,20 +1,21 @@ | ||||||
| package api | package api | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"github.com/bitly/go-simplejson" |  | ||||||
| 	"io/ioutil" | 	"io/ioutil" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/bitly/go-simplejson" | ||||||
|  | 
 | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func testBackend(response_code int, payload string) *httptest.Server { | func testBackend(responseCode int, payload string) *httptest.Server { | ||||||
| 	return httptest.NewServer(http.HandlerFunc( | 	return httptest.NewServer(http.HandlerFunc( | ||||||
| 		func(w http.ResponseWriter, r *http.Request) { | 		func(w http.ResponseWriter, r *http.Request) { | ||||||
| 			w.WriteHeader(response_code) | 			w.WriteHeader(responseCode) | ||||||
| 			w.Write([]byte(payload)) | 			w.Write([]byte(payload)) | ||||||
| 		})) | 		})) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -0,0 +1,137 @@ | ||||||
|  | #!/usr/bin/env bash | ||||||
|  | 
 | ||||||
|  | RED='\033[0;31m' | ||||||
|  | GREEN='\033[0;32m' | ||||||
|  | BLUE='\033[0;34m' | ||||||
|  | NC='\033[0m' | ||||||
|  | 
 | ||||||
|  | declare -A tools=() | ||||||
|  | declare -A desired=() | ||||||
|  | 
 | ||||||
|  | for arg in "$@"; do | ||||||
|  |   case ${arg%%=*} in | ||||||
|  |     "--with-go") | ||||||
|  |       desired[go]="${arg##*=}" | ||||||
|  |       ;; | ||||||
|  |     "--with-dep") | ||||||
|  |       desired[dep]="${arg##*=}" | ||||||
|  |       ;; | ||||||
|  |     "--help") | ||||||
|  |       printf "${GREEN}$0${NC}\n" | ||||||
|  |       printf "  available options:\n" | ||||||
|  |       printf "  --with-dep=${BLUE}<path_to_dep_binary>${NC}\n" | ||||||
|  |       printf "  --with-go=${BLUE}<path_to_go_binary>${NC}\n" | ||||||
|  |       exit 0 | ||||||
|  |       ;; | ||||||
|  |     *) | ||||||
|  |       echo "Unknown option: $arg" | ||||||
|  |       exit 2 | ||||||
|  |       ;; | ||||||
|  |     esac | ||||||
|  | done | ||||||
|  | 
 | ||||||
|  | vercomp () { | ||||||
|  |     if [[ $1 == $2 ]] | ||||||
|  |     then | ||||||
|  |         return 0 | ||||||
|  |     fi | ||||||
|  |     local IFS=. | ||||||
|  |     local i ver1=($1) ver2=($2) | ||||||
|  |     # fill empty fields in ver1 with zeros | ||||||
|  |     for ((i=${#ver1[@]}; i<${#ver2[@]}; i++)) | ||||||
|  |     do | ||||||
|  |         ver1[i]=0 | ||||||
|  |     done | ||||||
|  |     for ((i=0; i<${#ver1[@]}; i++)) | ||||||
|  |     do | ||||||
|  |         if [[ -z ${ver2[i]} ]] | ||||||
|  |         then | ||||||
|  |             # fill empty fields in ver2 with zeros | ||||||
|  |             ver2[i]=0 | ||||||
|  |         fi | ||||||
|  |         if ((10#${ver1[i]} > 10#${ver2[i]})) | ||||||
|  |         then | ||||||
|  |             return 1 | ||||||
|  |         fi | ||||||
|  |         if ((10#${ver1[i]} < 10#${ver2[i]})) | ||||||
|  |         then | ||||||
|  |             return 2 | ||||||
|  |         fi | ||||||
|  |     done | ||||||
|  |     return 0 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | check_for() { | ||||||
|  |   echo -n "Checking for $1... " | ||||||
|  |   if ! [ -z "${desired[$1]}" ]; then | ||||||
|  |     TOOL_PATH="${desired[$1]}" | ||||||
|  |   else | ||||||
|  |     TOOL_PATH=$(command -v $1) | ||||||
|  |   fi | ||||||
|  |   if ! [ -x "$TOOL_PATH" -a -f "$TOOL_PATH" ]; then | ||||||
|  |     printf "${RED}not found${NC}\n" | ||||||
|  |     cd - | ||||||
|  |     exit 1 | ||||||
|  |   else | ||||||
|  |     printf "${GREEN}found${NC}\n" | ||||||
|  |     tools[$1]=$TOOL_PATH | ||||||
|  |   fi | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | check_go_version() { | ||||||
|  |   echo -n "Checking go version... " | ||||||
|  |   GO_VERSION=$(${tools[go]} version | ${tools[awk]} '{where = match($0, /[0-9]\.[0-9]+\.[0-9]*/); if (where != 0) print substr($0, RSTART, RLENGTH)}') | ||||||
|  |   vercomp $GO_VERSION 1.9 | ||||||
|  |   case $? in | ||||||
|  |     0) ;& | ||||||
|  |     1) | ||||||
|  |       printf "${GREEN}" | ||||||
|  |       echo $GO_VERSION | ||||||
|  |       printf "${NC}" | ||||||
|  |       ;; | ||||||
|  |     2) | ||||||
|  |       printf "${RED}" | ||||||
|  |       echo "$GO_VERSION < 1.9" | ||||||
|  |       exit 1 | ||||||
|  |       ;; | ||||||
|  |   esac | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | check_docker_version() { | ||||||
|  |   echo -n "Checking docker version... " | ||||||
|  |   DOCKER_VERSION=$(${tools[docker]} version | ${tools[awk]}) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | check_go_env() { | ||||||
|  |   echo -n "Checking \$GOPATH... " | ||||||
|  |   if [ -z "$GOPATH" ]; then | ||||||
|  |     printf "${RED}invalid${NC} - GOPATH not set\n" | ||||||
|  |     exit 1 | ||||||
|  |   fi | ||||||
|  |   printf "${GREEN}valid${NC} - $GOPATH\n" | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | cd ${0%/*} | ||||||
|  | 
 | ||||||
|  | if [ ! -f .env ]; then | ||||||
|  |   rm .env | ||||||
|  | fi | ||||||
|  | 
 | ||||||
|  | check_for make | ||||||
|  | check_for awk | ||||||
|  | check_for go | ||||||
|  | check_go_version | ||||||
|  | check_go_env | ||||||
|  | check_for dep | ||||||
|  | 
 | ||||||
|  | echo | ||||||
|  | 
 | ||||||
|  | cat <<- EOF > .env | ||||||
|  | 	MAKE := "${tools[make]}" | ||||||
|  |   GO := "${tools[go]}" | ||||||
|  |   DEP := "${tools[dep]}" | ||||||
|  | EOF | ||||||
|  | 
 | ||||||
|  | echo "Environment configuration written to .env" | ||||||
|  | 
 | ||||||
|  | cd - > /dev/null | ||||||
|  | @ -24,10 +24,11 @@ func TestEncodeAndDecodeAccessToken(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestEncodeAndDecodeAccessTokenB64(t *testing.T) { | func TestEncodeAndDecodeAccessTokenB64(t *testing.T) { | ||||||
| 	const secret_b64 = "A3Xbr6fu6Al0HkgrP1ztjb-mYiwmxgNPP-XbNsz1WBk=" | 	const secretBase64 = "A3Xbr6fu6Al0HkgrP1ztjb-mYiwmxgNPP-XbNsz1WBk=" | ||||||
| 	const token = "my access token" | 	const token = "my access token" | ||||||
| 
 | 
 | ||||||
| 	secret, err := base64.URLEncoding.DecodeString(secret_b64) | 	secret, err := base64.URLEncoding.DecodeString(secretBase64) | ||||||
|  | 	assert.Equal(t, nil, err) | ||||||
| 	c, err := NewCipher([]byte(secret)) | 	c, err := NewCipher([]byte(secret)) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -5,6 +5,7 @@ import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // Nonce generates a random 16 byte string to be used as a nonce
 | ||||||
| func Nonce() (nonce string, err error) { | func Nonce() (nonce string, err error) { | ||||||
| 	b := make([]byte, 16) | 	b := make([]byte, 16) | ||||||
| 	_, err = rand.Read(b) | 	_, err = rand.Read(b) | ||||||
|  |  | ||||||
|  | @ -6,8 +6,14 @@ import ( | ||||||
| 	"strings" | 	"strings" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // EnvOptions holds program options loaded from the process environment
 | ||||||
| type EnvOptions map[string]interface{} | type EnvOptions map[string]interface{} | ||||||
| 
 | 
 | ||||||
|  | // LoadEnvForStruct loads environment variables for each field in an options
 | ||||||
|  | // struct passed into it.
 | ||||||
|  | //
 | ||||||
|  | // Fields in the options struct must have an `env` and `cfg` tag to be read
 | ||||||
|  | // from the environment
 | ||||||
| func (cfg EnvOptions) LoadEnvForStruct(options interface{}) { | func (cfg EnvOptions) LoadEnvForStruct(options interface{}) { | ||||||
| 	val := reflect.ValueOf(options).Elem() | 	val := reflect.ValueOf(options).Elem() | ||||||
| 	typ := val.Type() | 	typ := val.Type() | ||||||
|  |  | ||||||
							
								
								
									
										14
									
								
								htpasswd.go
								
								
								
								
							
							
						
						
									
										14
									
								
								htpasswd.go
								
								
								
								
							|  | @ -14,10 +14,12 @@ import ( | ||||||
| // Lookup passwords in a htpasswd file
 | // Lookup passwords in a htpasswd file
 | ||||||
| // Passwords must be generated with -B for bcrypt or -s for SHA1.
 | // Passwords must be generated with -B for bcrypt or -s for SHA1.
 | ||||||
| 
 | 
 | ||||||
|  | // HtpasswdFile represents the structure of an htpasswd file
 | ||||||
| type HtpasswdFile struct { | type HtpasswdFile struct { | ||||||
| 	Users map[string]string | 	Users map[string]string | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // NewHtpasswdFromFile constructs an HtpasswdFile from the file at the path given
 | ||||||
| func NewHtpasswdFromFile(path string) (*HtpasswdFile, error) { | func NewHtpasswdFromFile(path string) (*HtpasswdFile, error) { | ||||||
| 	r, err := os.Open(path) | 	r, err := os.Open(path) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | @ -27,13 +29,14 @@ func NewHtpasswdFromFile(path string) (*HtpasswdFile, error) { | ||||||
| 	return NewHtpasswd(r) | 	return NewHtpasswd(r) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // NewHtpasswd  consctructs an HtpasswdFile from an io.Reader (opened file)
 | ||||||
| func NewHtpasswd(file io.Reader) (*HtpasswdFile, error) { | func NewHtpasswd(file io.Reader) (*HtpasswdFile, error) { | ||||||
| 	csv_reader := csv.NewReader(file) | 	csvReader := csv.NewReader(file) | ||||||
| 	csv_reader.Comma = ':' | 	csvReader.Comma = ':' | ||||||
| 	csv_reader.Comment = '#' | 	csvReader.Comment = '#' | ||||||
| 	csv_reader.TrimLeadingSpace = true | 	csvReader.TrimLeadingSpace = true | ||||||
| 
 | 
 | ||||||
| 	records, err := csv_reader.ReadAll() | 	records, err := csvReader.ReadAll() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  | @ -44,6 +47,7 @@ func NewHtpasswd(file io.Reader) (*HtpasswdFile, error) { | ||||||
| 	return h, nil | 	return h, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // Validate checks a users password against the HtpasswdFile entries
 | ||||||
| func (h *HtpasswdFile) Validate(user string, password string) bool { | func (h *HtpasswdFile) Validate(user string, password string) bool { | ||||||
| 	realPassword, exists := h.Users[user] | 	realPassword, exists := h.Users[user] | ||||||
| 	if !exists { | 	if !exists { | ||||||
|  |  | ||||||
|  | @ -20,6 +20,7 @@ func TestSHA(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| func TestBcrypt(t *testing.T) { | func TestBcrypt(t *testing.T) { | ||||||
| 	hash1, err := bcrypt.GenerateFromPassword([]byte("password"), 1) | 	hash1, err := bcrypt.GenerateFromPassword([]byte("password"), 1) | ||||||
|  | 	assert.Equal(t, err, nil) | ||||||
| 	hash2, err := bcrypt.GenerateFromPassword([]byte("top-secret"), 2) | 	hash2, err := bcrypt.GenerateFromPassword([]byte("top-secret"), 2) | ||||||
| 	assert.Equal(t, err, nil) | 	assert.Equal(t, err, nil) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
							
								
								
									
										16
									
								
								http.go
								
								
								
								
							
							
						
						
									
										16
									
								
								http.go
								
								
								
								
							|  | @ -9,11 +9,13 @@ import ( | ||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // Server represents an HTTP server
 | ||||||
| type Server struct { | type Server struct { | ||||||
| 	Handler http.Handler | 	Handler http.Handler | ||||||
| 	Opts    *Options | 	Opts    *Options | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // ListenAndServe will serve traffic on HTTP or HTTPS depending on TLS options
 | ||||||
| func (s *Server) ListenAndServe() { | func (s *Server) ListenAndServe() { | ||||||
| 	if s.Opts.TLSKeyFile != "" || s.Opts.TLSCertFile != "" { | 	if s.Opts.TLSKeyFile != "" || s.Opts.TLSCertFile != "" { | ||||||
| 		s.ServeHTTPS() | 		s.ServeHTTPS() | ||||||
|  | @ -22,13 +24,14 @@ func (s *Server) ListenAndServe() { | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // ServeHTTP constructs a net.Listener and starts handling HTTP requests
 | ||||||
| func (s *Server) ServeHTTP() { | func (s *Server) ServeHTTP() { | ||||||
| 	httpAddress := s.Opts.HttpAddress | 	HTTPAddress := s.Opts.HTTPAddress | ||||||
| 	scheme := "" | 	var scheme string | ||||||
| 
 | 
 | ||||||
| 	i := strings.Index(httpAddress, "://") | 	i := strings.Index(HTTPAddress, "://") | ||||||
| 	if i > -1 { | 	if i > -1 { | ||||||
| 		scheme = httpAddress[0:i] | 		scheme = HTTPAddress[0:i] | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	var networkType string | 	var networkType string | ||||||
|  | @ -39,7 +42,7 @@ func (s *Server) ServeHTTP() { | ||||||
| 		networkType = scheme | 		networkType = scheme | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	slice := strings.SplitN(httpAddress, "//", 2) | 	slice := strings.SplitN(HTTPAddress, "//", 2) | ||||||
| 	listenAddr := slice[len(slice)-1] | 	listenAddr := slice[len(slice)-1] | ||||||
| 
 | 
 | ||||||
| 	listener, err := net.Listen(networkType, listenAddr) | 	listener, err := net.Listen(networkType, listenAddr) | ||||||
|  | @ -57,8 +60,9 @@ func (s *Server) ServeHTTP() { | ||||||
| 	log.Printf("HTTP: closing %s", listener.Addr()) | 	log.Printf("HTTP: closing %s", listener.Addr()) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // ServeHTTPS constructs a net.Listener and starts handling HTTPS requests
 | ||||||
| func (s *Server) ServeHTTPS() { | func (s *Server) ServeHTTPS() { | ||||||
| 	addr := s.Opts.HttpsAddress | 	addr := s.Opts.HTTPSAddress | ||||||
| 	config := &tls.Config{ | 	config := &tls.Config{ | ||||||
| 		MinVersion: tls.VersionTLS12, | 		MinVersion: tls.VersionTLS12, | ||||||
| 		MaxVersion: tls.VersionTLS12, | 		MaxVersion: tls.VersionTLS12, | ||||||
|  |  | ||||||
|  | @ -27,10 +27,13 @@ type responseLogger struct { | ||||||
| 	authInfo string | 	authInfo string | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // Header returns the ResponseWriter's Header
 | ||||||
| func (l *responseLogger) Header() http.Header { | func (l *responseLogger) Header() http.Header { | ||||||
| 	return l.w.Header() | 	return l.w.Header() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // ExtractGAPMetadata extracts and removes GAP headers from the ResponseWriter's
 | ||||||
|  | // Header
 | ||||||
| func (l *responseLogger) ExtractGAPMetadata() { | func (l *responseLogger) ExtractGAPMetadata() { | ||||||
| 	upstream := l.w.Header().Get("GAP-Upstream-Address") | 	upstream := l.w.Header().Get("GAP-Upstream-Address") | ||||||
| 	if upstream != "" { | 	if upstream != "" { | ||||||
|  | @ -44,6 +47,7 @@ func (l *responseLogger) ExtractGAPMetadata() { | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // Write writes the response using the ResponseWriter
 | ||||||
| func (l *responseLogger) Write(b []byte) (int, error) { | func (l *responseLogger) Write(b []byte) (int, error) { | ||||||
| 	if l.status == 0 { | 	if l.status == 0 { | ||||||
| 		// The status will be StatusOK if WriteHeader has not been called yet
 | 		// The status will be StatusOK if WriteHeader has not been called yet
 | ||||||
|  | @ -55,16 +59,19 @@ func (l *responseLogger) Write(b []byte) (int, error) { | ||||||
| 	return size, err | 	return size, err | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // WriteHeader writes the status code for the Response
 | ||||||
| func (l *responseLogger) WriteHeader(s int) { | func (l *responseLogger) WriteHeader(s int) { | ||||||
| 	l.ExtractGAPMetadata() | 	l.ExtractGAPMetadata() | ||||||
| 	l.w.WriteHeader(s) | 	l.w.WriteHeader(s) | ||||||
| 	l.status = s | 	l.status = s | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // Status returns the response status code
 | ||||||
| func (l *responseLogger) Status() int { | func (l *responseLogger) Status() int { | ||||||
| 	return l.status | 	return l.status | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // Size returns teh response size
 | ||||||
| func (l *responseLogger) Size() int { | func (l *responseLogger) Size() int { | ||||||
| 	return l.size | 	return l.size | ||||||
| } | } | ||||||
|  | @ -94,6 +101,7 @@ type loggingHandler struct { | ||||||
| 	logTemplate *template.Template | 	logTemplate *template.Template | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // LoggingHandler provides an http.Handler which logs requests to the HTTP server
 | ||||||
| func LoggingHandler(out io.Writer, h http.Handler, v bool, requestLoggingTpl string) http.Handler { | func LoggingHandler(out io.Writer, h http.Handler, v bool, requestLoggingTpl string) http.Handler { | ||||||
| 	return loggingHandler{ | 	return loggingHandler{ | ||||||
| 		writer:      out, | 		writer:      out, | ||||||
|  |  | ||||||
							
								
								
									
										2
									
								
								main.go
								
								
								
								
							
							
						
						
									
										2
									
								
								main.go
								
								
								
								
							|  | @ -84,7 +84,7 @@ func main() { | ||||||
| 	flagSet.Parse(os.Args[1:]) | 	flagSet.Parse(os.Args[1:]) | ||||||
| 
 | 
 | ||||||
| 	if *showVersion { | 	if *showVersion { | ||||||
| 		fmt.Printf("oauth2_proxy v%s (built with %s)\n", VERSION, runtime.Version()) | 		fmt.Printf("oauth2_proxy %s (built with %s)\n", VERSION, runtime.Version()) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -14,14 +14,23 @@ import ( | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/bitly/oauth2_proxy/cookie" |  | ||||||
| 	"github.com/bitly/oauth2_proxy/providers" |  | ||||||
| 	"github.com/mbland/hmacauth" | 	"github.com/mbland/hmacauth" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/cookie" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/providers" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const SignatureHeader = "GAP-Signature" | const ( | ||||||
|  | 	// SignatureHeader is the name of the request header containing the GAP Signature
 | ||||||
|  | 	// Part of hmacauth
 | ||||||
|  | 	SignatureHeader = "GAP-Signature" | ||||||
| 
 | 
 | ||||||
| var SignatureHeaders []string = []string{ | 	httpScheme  = "http" | ||||||
|  | 	httpsScheme = "https" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // SignatureHeaders contains the headers to be signed by the hmac algorithm
 | ||||||
|  | // Part of hmacauth
 | ||||||
|  | var SignatureHeaders = []string{ | ||||||
| 	"Content-Length", | 	"Content-Length", | ||||||
| 	"Content-Md5", | 	"Content-Md5", | ||||||
| 	"Content-Type", | 	"Content-Type", | ||||||
|  | @ -34,13 +43,14 @@ var SignatureHeaders []string = []string{ | ||||||
| 	"Gap-Auth", | 	"Gap-Auth", | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // OAuthProxy is the main authentication proxy
 | ||||||
| type OAuthProxy struct { | type OAuthProxy struct { | ||||||
| 	CookieSeed     string | 	CookieSeed     string | ||||||
| 	CookieName     string | 	CookieName     string | ||||||
| 	CSRFCookieName string | 	CSRFCookieName string | ||||||
| 	CookieDomain   string | 	CookieDomain   string | ||||||
| 	CookieSecure   bool | 	CookieSecure   bool | ||||||
| 	CookieHttpOnly bool | 	CookieHTTPOnly bool | ||||||
| 	CookieExpire   time.Duration | 	CookieExpire   time.Duration | ||||||
| 	CookieRefresh  time.Duration | 	CookieRefresh  time.Duration | ||||||
| 	Validator      func(string) bool | 	Validator      func(string) bool | ||||||
|  | @ -74,12 +84,15 @@ type OAuthProxy struct { | ||||||
| 	Footer              string | 	Footer              string | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // UpstreamProxy represents an upstream server to proxy to
 | ||||||
| type UpstreamProxy struct { | type UpstreamProxy struct { | ||||||
| 	upstream string | 	upstream string | ||||||
| 	handler  http.Handler | 	handler  http.Handler | ||||||
| 	auth     hmacauth.HmacAuth | 	auth     hmacauth.HmacAuth | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // ServeHTTP proxies requests to the upstream provider while signing the
 | ||||||
|  | // request headers
 | ||||||
| func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { | func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { | ||||||
| 	w.Header().Set("GAP-Upstream-Address", u.upstream) | 	w.Header().Set("GAP-Upstream-Address", u.upstream) | ||||||
| 	if u.auth != nil { | 	if u.auth != nil { | ||||||
|  | @ -89,9 +102,12 @@ func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { | ||||||
| 	u.handler.ServeHTTP(w, r) | 	u.handler.ServeHTTP(w, r) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // NewReverseProxy creates a new reverse proxy for proxying requests to upstream
 | ||||||
|  | // servers
 | ||||||
| func NewReverseProxy(target *url.URL) (proxy *httputil.ReverseProxy) { | func NewReverseProxy(target *url.URL) (proxy *httputil.ReverseProxy) { | ||||||
| 	return httputil.NewSingleHostReverseProxy(target) | 	return httputil.NewSingleHostReverseProxy(target) | ||||||
| } | } | ||||||
|  | 
 | ||||||
| func setProxyUpstreamHostHeader(proxy *httputil.ReverseProxy, target *url.URL) { | func setProxyUpstreamHostHeader(proxy *httputil.ReverseProxy, target *url.URL) { | ||||||
| 	director := proxy.Director | 	director := proxy.Director | ||||||
| 	proxy.Director = func(req *http.Request) { | 	proxy.Director = func(req *http.Request) { | ||||||
|  | @ -102,6 +118,7 @@ func setProxyUpstreamHostHeader(proxy *httputil.ReverseProxy, target *url.URL) { | ||||||
| 		req.URL.RawQuery = "" | 		req.URL.RawQuery = "" | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
| func setProxyDirector(proxy *httputil.ReverseProxy) { | func setProxyDirector(proxy *httputil.ReverseProxy) { | ||||||
| 	director := proxy.Director | 	director := proxy.Director | ||||||
| 	proxy.Director = func(req *http.Request) { | 	proxy.Director = func(req *http.Request) { | ||||||
|  | @ -111,10 +128,13 @@ func setProxyDirector(proxy *httputil.ReverseProxy) { | ||||||
| 		req.URL.RawQuery = "" | 		req.URL.RawQuery = "" | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // NewFileServer creates a http.Handler to serve files from the filesystem
 | ||||||
| func NewFileServer(path string, filesystemPath string) (proxy http.Handler) { | func NewFileServer(path string, filesystemPath string) (proxy http.Handler) { | ||||||
| 	return http.StripPrefix(path, http.FileServer(http.Dir(filesystemPath))) | 	return http.StripPrefix(path, http.FileServer(http.Dir(filesystemPath))) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // NewOAuthProxy creates a new instance of OOuthProxy from the options provided
 | ||||||
| func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { | func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { | ||||||
| 	serveMux := http.NewServeMux() | 	serveMux := http.NewServeMux() | ||||||
| 	var auth hmacauth.HmacAuth | 	var auth hmacauth.HmacAuth | ||||||
|  | @ -125,7 +145,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { | ||||||
| 	for _, u := range opts.proxyURLs { | 	for _, u := range opts.proxyURLs { | ||||||
| 		path := u.Path | 		path := u.Path | ||||||
| 		switch u.Scheme { | 		switch u.Scheme { | ||||||
| 		case "http", "https": | 		case httpScheme, httpsScheme: | ||||||
| 			u.Path = "" | 			u.Path = "" | ||||||
| 			log.Printf("mapping path %q => upstream %q", path, u) | 			log.Printf("mapping path %q => upstream %q", path, u) | ||||||
| 			proxy := NewReverseProxy(u) | 			proxy := NewReverseProxy(u) | ||||||
|  | @ -160,7 +180,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { | ||||||
| 		refresh = fmt.Sprintf("after %s", opts.CookieRefresh) | 		refresh = fmt.Sprintf("after %s", opts.CookieRefresh) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	log.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, opts.CookieDomain, refresh) | 	log.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHTTPOnly, opts.CookieExpire, opts.CookieDomain, refresh) | ||||||
| 
 | 
 | ||||||
| 	var cipher *cookie.Cipher | 	var cipher *cookie.Cipher | ||||||
| 	if opts.PassAccessToken || (opts.CookieRefresh != time.Duration(0)) { | 	if opts.PassAccessToken || (opts.CookieRefresh != time.Duration(0)) { | ||||||
|  | @ -177,7 +197,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { | ||||||
| 		CookieSeed:     opts.CookieSecret, | 		CookieSeed:     opts.CookieSecret, | ||||||
| 		CookieDomain:   opts.CookieDomain, | 		CookieDomain:   opts.CookieDomain, | ||||||
| 		CookieSecure:   opts.CookieSecure, | 		CookieSecure:   opts.CookieSecure, | ||||||
| 		CookieHttpOnly: opts.CookieHttpOnly, | 		CookieHTTPOnly: opts.CookieHTTPOnly, | ||||||
| 		CookieExpire:   opts.CookieExpire, | 		CookieExpire:   opts.CookieExpire, | ||||||
| 		CookieRefresh:  opts.CookieRefresh, | 		CookieRefresh:  opts.CookieRefresh, | ||||||
| 		Validator:      validator, | 		Validator:      validator, | ||||||
|  | @ -209,6 +229,8 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // GetRedirectURI returns the redirectURL that the upstream OAuth Provider will
 | ||||||
|  | // redirect clients to once authenticated
 | ||||||
| func (p *OAuthProxy) GetRedirectURI(host string) string { | func (p *OAuthProxy) GetRedirectURI(host string) string { | ||||||
| 	// default to the request Host if not set
 | 	// default to the request Host if not set
 | ||||||
| 	if p.redirectURL.Host != "" { | 	if p.redirectURL.Host != "" { | ||||||
|  | @ -218,9 +240,9 @@ func (p *OAuthProxy) GetRedirectURI(host string) string { | ||||||
| 	u = *p.redirectURL | 	u = *p.redirectURL | ||||||
| 	if u.Scheme == "" { | 	if u.Scheme == "" { | ||||||
| 		if p.CookieSecure { | 		if p.CookieSecure { | ||||||
| 			u.Scheme = "https" | 			u.Scheme = httpsScheme | ||||||
| 		} else { | 		} else { | ||||||
| 			u.Scheme = "http" | 			u.Scheme = httpScheme | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	u.Host = host | 	u.Host = host | ||||||
|  | @ -254,6 +276,8 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, e | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // MakeSessionCookie creates an http.Cookie containing the authenticated user's
 | ||||||
|  | // authentication details
 | ||||||
| func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { | func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { | ||||||
| 	if value != "" { | 	if value != "" { | ||||||
| 		value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now) | 		value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now) | ||||||
|  | @ -265,6 +289,7 @@ func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expirati | ||||||
| 	return p.makeCookie(req, p.CookieName, value, expiration, now) | 	return p.makeCookie(req, p.CookieName, value, expiration, now) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // MakeCSRFCookie creates a cookie for CSRF
 | ||||||
| func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { | func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { | ||||||
| 	return p.makeCookie(req, p.CSRFCookieName, value, expiration, now) | 	return p.makeCookie(req, p.CSRFCookieName, value, expiration, now) | ||||||
| } | } | ||||||
|  | @ -285,20 +310,25 @@ func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, ex | ||||||
| 		Value:    value, | 		Value:    value, | ||||||
| 		Path:     "/", | 		Path:     "/", | ||||||
| 		Domain:   p.CookieDomain, | 		Domain:   p.CookieDomain, | ||||||
| 		HttpOnly: p.CookieHttpOnly, | 		HttpOnly: p.CookieHTTPOnly, | ||||||
| 		Secure:   p.CookieSecure, | 		Secure:   p.CookieSecure, | ||||||
| 		Expires:  now.Add(expiration), | 		Expires:  now.Add(expiration), | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // ClearCSRFCookie creates a cookie to unset the CSRF cookie stored in the user's
 | ||||||
|  | // session
 | ||||||
| func (p *OAuthProxy) ClearCSRFCookie(rw http.ResponseWriter, req *http.Request) { | func (p *OAuthProxy) ClearCSRFCookie(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	http.SetCookie(rw, p.MakeCSRFCookie(req, "", time.Hour*-1, time.Now())) | 	http.SetCookie(rw, p.MakeCSRFCookie(req, "", time.Hour*-1, time.Now())) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // SetCSRFCookie adds a CSRF cookie to the response
 | ||||||
| func (p *OAuthProxy) SetCSRFCookie(rw http.ResponseWriter, req *http.Request, val string) { | func (p *OAuthProxy) SetCSRFCookie(rw http.ResponseWriter, req *http.Request, val string) { | ||||||
| 	http.SetCookie(rw, p.MakeCSRFCookie(req, val, p.CookieExpire, time.Now())) | 	http.SetCookie(rw, p.MakeCSRFCookie(req, val, p.CookieExpire, time.Now())) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // ClearSessionCookie creates a cookie to unset the user's authentication cookie
 | ||||||
|  | // stored in the user's session
 | ||||||
| func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) { | func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	clr := p.MakeSessionCookie(req, "", time.Hour*-1, time.Now()) | 	clr := p.MakeSessionCookie(req, "", time.Hour*-1, time.Now()) | ||||||
| 	http.SetCookie(rw, clr) | 	http.SetCookie(rw, clr) | ||||||
|  | @ -311,10 +341,12 @@ func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Reques | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // SetSessionCookie adds the user's session cookie to the response
 | ||||||
| func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, val string) { | func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, val string) { | ||||||
| 	http.SetCookie(rw, p.MakeSessionCookie(req, val, p.CookieExpire, time.Now())) | 	http.SetCookie(rw, p.MakeSessionCookie(req, val, p.CookieExpire, time.Now())) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // LoadCookiedSession reads the user's authentication details from the request
 | ||||||
| func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) { | func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) { | ||||||
| 	var age time.Duration | 	var age time.Duration | ||||||
| 	c, err := req.Cookie(p.CookieName) | 	c, err := req.Cookie(p.CookieName) | ||||||
|  | @ -336,6 +368,7 @@ func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionSt | ||||||
| 	return session, age, nil | 	return session, age, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // SaveSession creates a new session cookie value and sets this on the response
 | ||||||
| func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *providers.SessionState) error { | func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *providers.SessionState) error { | ||||||
| 	value, err := p.provider.CookieForSession(s, p.CookieCipher) | 	value, err := p.provider.CookieForSession(s, p.CookieCipher) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | @ -345,16 +378,19 @@ func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *p | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // RobotsTxt disallows scraping pages from the OAuthProxy
 | ||||||
| func (p *OAuthProxy) RobotsTxt(rw http.ResponseWriter) { | func (p *OAuthProxy) RobotsTxt(rw http.ResponseWriter) { | ||||||
| 	rw.WriteHeader(http.StatusOK) | 	rw.WriteHeader(http.StatusOK) | ||||||
| 	fmt.Fprintf(rw, "User-agent: *\nDisallow: /") | 	fmt.Fprintf(rw, "User-agent: *\nDisallow: /") | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // PingPage responds 200 OK to requests
 | ||||||
| func (p *OAuthProxy) PingPage(rw http.ResponseWriter) { | func (p *OAuthProxy) PingPage(rw http.ResponseWriter) { | ||||||
| 	rw.WriteHeader(http.StatusOK) | 	rw.WriteHeader(http.StatusOK) | ||||||
| 	fmt.Fprintf(rw, "OK") | 	fmt.Fprintf(rw, "OK") | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // ErrorPage writes an error response
 | ||||||
| func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, message string) { | func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, message string) { | ||||||
| 	log.Printf("ErrorPage %d %s %s", code, title, message) | 	log.Printf("ErrorPage %d %s %s", code, title, message) | ||||||
| 	rw.WriteHeader(code) | 	rw.WriteHeader(code) | ||||||
|  | @ -370,16 +406,17 @@ func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, m | ||||||
| 	p.templates.ExecuteTemplate(rw, "error.html", t) | 	p.templates.ExecuteTemplate(rw, "error.html", t) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // SignInPage writes the sing in template to the response
 | ||||||
| func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code int) { | func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code int) { | ||||||
| 	p.ClearSessionCookie(rw, req) | 	p.ClearSessionCookie(rw, req) | ||||||
| 	rw.WriteHeader(code) | 	rw.WriteHeader(code) | ||||||
| 
 | 
 | ||||||
| 	redirect_url := req.URL.RequestURI() | 	redirecURL := req.URL.RequestURI() | ||||||
| 	if req.Header.Get("X-Auth-Request-Redirect") != "" { | 	if req.Header.Get("X-Auth-Request-Redirect") != "" { | ||||||
| 		redirect_url = req.Header.Get("X-Auth-Request-Redirect") | 		redirecURL = req.Header.Get("X-Auth-Request-Redirect") | ||||||
| 	} | 	} | ||||||
| 	if redirect_url == p.SignInPath { | 	if redirecURL == p.SignInPath { | ||||||
| 		redirect_url = "/" | 		redirecURL = "/" | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	t := struct { | 	t := struct { | ||||||
|  | @ -394,7 +431,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code | ||||||
| 		ProviderName:  p.provider.Data().ProviderName, | 		ProviderName:  p.provider.Data().ProviderName, | ||||||
| 		SignInMessage: p.SignInMessage, | 		SignInMessage: p.SignInMessage, | ||||||
| 		CustomLogin:   p.displayCustomLoginForm(), | 		CustomLogin:   p.displayCustomLoginForm(), | ||||||
| 		Redirect:      redirect_url, | 		Redirect:      redirecURL, | ||||||
| 		Version:       VERSION, | 		Version:       VERSION, | ||||||
| 		ProxyPrefix:   p.ProxyPrefix, | 		ProxyPrefix:   p.ProxyPrefix, | ||||||
| 		Footer:        template.HTML(p.Footer), | 		Footer:        template.HTML(p.Footer), | ||||||
|  | @ -402,6 +439,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code | ||||||
| 	p.templates.ExecuteTemplate(rw, "sign_in.html", t) | 	p.templates.ExecuteTemplate(rw, "sign_in.html", t) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // ManualSignIn handles basic auth logins to the proxy
 | ||||||
| func (p *OAuthProxy) ManualSignIn(rw http.ResponseWriter, req *http.Request) (string, bool) { | func (p *OAuthProxy) ManualSignIn(rw http.ResponseWriter, req *http.Request) (string, bool) { | ||||||
| 	if req.Method != "POST" || p.HtpasswdFile == nil { | 	if req.Method != "POST" || p.HtpasswdFile == nil { | ||||||
| 		return "", false | 		return "", false | ||||||
|  | @ -419,6 +457,8 @@ func (p *OAuthProxy) ManualSignIn(rw http.ResponseWriter, req *http.Request) (st | ||||||
| 	return "", false | 	return "", false | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // GetRedirect reads the query parameter to get the URL to redirect clients to
 | ||||||
|  | // once authenticated with the OAuthProxy
 | ||||||
| func (p *OAuthProxy) GetRedirect(req *http.Request) (redirect string, err error) { | func (p *OAuthProxy) GetRedirect(req *http.Request) (redirect string, err error) { | ||||||
| 	err = req.ParseForm() | 	err = req.ParseForm() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | @ -433,11 +473,13 @@ func (p *OAuthProxy) GetRedirect(req *http.Request) (redirect string, err error) | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // IsWhitelistedRequest is used to check if auth should be skipped for this request
 | ||||||
| func (p *OAuthProxy) IsWhitelistedRequest(req *http.Request) (ok bool) { | func (p *OAuthProxy) IsWhitelistedRequest(req *http.Request) (ok bool) { | ||||||
| 	isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS" | 	isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS" | ||||||
| 	return isPreflightRequestAllowed || p.IsWhitelistedPath(req.URL.Path) | 	return isPreflightRequestAllowed || p.IsWhitelistedPath(req.URL.Path) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // IsWhitelistedPath is used to check if the request path is allowed without auth
 | ||||||
| func (p *OAuthProxy) IsWhitelistedPath(path string) (ok bool) { | func (p *OAuthProxy) IsWhitelistedPath(path string) (ok bool) { | ||||||
| 	for _, u := range p.compiledRegex { | 	for _, u := range p.compiledRegex { | ||||||
| 		ok = u.MatchString(path) | 		ok = u.MatchString(path) | ||||||
|  | @ -479,6 +521,7 @@ func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // SignIn serves a page prompting users to sign in
 | ||||||
| func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { | func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	redirect, err := p.GetRedirect(req) | 	redirect, err := p.GetRedirect(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | @ -500,11 +543,13 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // SignOut sends a response to clear the authentication cookie
 | ||||||
| func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) { | func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	p.ClearSessionCookie(rw, req) | 	p.ClearSessionCookie(rw, req) | ||||||
| 	http.Redirect(rw, req, "/", 302) | 	http.Redirect(rw, req, "/", 302) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // OAuthStart starts the OAuth2 authentication flow
 | ||||||
| func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) { | func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	nonce, err := cookie.Nonce() | 	nonce, err := cookie.Nonce() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | @ -521,6 +566,8 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), 302) | 	http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), 302) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // OAuthCallback is the OAuth2 authentication flow callback that finishes the
 | ||||||
|  | // OAuth2 authentication flow
 | ||||||
| func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { | func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	remoteAddr := getRemoteAddr(req) | 	remoteAddr := getRemoteAddr(req) | ||||||
| 
 | 
 | ||||||
|  | @ -582,6 +629,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // AuthenticateOnly checks whether the user is currently logged in
 | ||||||
| func (p *OAuthProxy) AuthenticateOnly(rw http.ResponseWriter, req *http.Request) { | func (p *OAuthProxy) AuthenticateOnly(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	status := p.Authenticate(rw, req) | 	status := p.Authenticate(rw, req) | ||||||
| 	if status == http.StatusAccepted { | 	if status == http.StatusAccepted { | ||||||
|  | @ -591,6 +639,8 @@ func (p *OAuthProxy) AuthenticateOnly(rw http.ResponseWriter, req *http.Request) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // Proxy proxies the user request if the user is authenticated else it prompts
 | ||||||
|  | // them to authenticate
 | ||||||
| func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { | func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	status := p.Authenticate(rw, req) | 	status := p.Authenticate(rw, req) | ||||||
| 	if status == http.StatusInternalServerError { | 	if status == http.StatusInternalServerError { | ||||||
|  | @ -607,6 +657,7 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // Authenticate checks whether a user is authenticated
 | ||||||
| func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int { | func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int { | ||||||
| 	var saveSession, clearSession, revalidated bool | 	var saveSession, clearSession, revalidated bool | ||||||
| 	remoteAddr := getRemoteAddr(req) | 	remoteAddr := getRemoteAddr(req) | ||||||
|  | @ -620,7 +671,8 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int | ||||||
| 		saveSession = true | 		saveSession = true | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if ok, err := p.provider.RefreshSessionIfNeeded(session); err != nil { | 	var ok bool | ||||||
|  | 	if ok, err = p.provider.RefreshSessionIfNeeded(session); err != nil { | ||||||
| 		log.Printf("%s removing session. error refreshing access token %s %s", remoteAddr, err, session) | 		log.Printf("%s removing session. error refreshing access token %s %s", remoteAddr, err, session) | ||||||
| 		clearSession = true | 		clearSession = true | ||||||
| 		session = nil | 		session = nil | ||||||
|  | @ -653,7 +705,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if saveSession && session != nil { | 	if saveSession && session != nil { | ||||||
| 		err := p.SaveSession(rw, req, session) | 		err = p.SaveSession(rw, req, session) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Printf("%s %s", remoteAddr, err) | 			log.Printf("%s %s", remoteAddr, err) | ||||||
| 			return http.StatusInternalServerError | 			return http.StatusInternalServerError | ||||||
|  | @ -706,6 +758,8 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int | ||||||
| 	return http.StatusAccepted | 	return http.StatusAccepted | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // CheckBasicAuth checks the requests Authorization header for basic auth
 | ||||||
|  | // credentials and authenticates these against the proxies HtpasswdFile
 | ||||||
| func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, error) { | func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, error) { | ||||||
| 	if p.HtpasswdFile == nil { | 	if p.HtpasswdFile == nil { | ||||||
| 		return nil, nil | 		return nil, nil | ||||||
|  |  | ||||||
|  | @ -15,8 +15,8 @@ import ( | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/bitly/oauth2_proxy/providers" |  | ||||||
| 	"github.com/mbland/hmacauth" | 	"github.com/mbland/hmacauth" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/providers" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -98,28 +98,28 @@ type TestProvider struct { | ||||||
| 	ValidToken   bool | 	ValidToken   bool | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func NewTestProvider(provider_url *url.URL, email_address string) *TestProvider { | func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider { | ||||||
| 	return &TestProvider{ | 	return &TestProvider{ | ||||||
| 		ProviderData: &providers.ProviderData{ | 		ProviderData: &providers.ProviderData{ | ||||||
| 			ProviderName: "Test Provider", | 			ProviderName: "Test Provider", | ||||||
| 			LoginURL: &url.URL{ | 			LoginURL: &url.URL{ | ||||||
| 				Scheme: "http", | 				Scheme: "http", | ||||||
| 				Host:   provider_url.Host, | 				Host:   providerURL.Host, | ||||||
| 				Path:   "/oauth/authorize", | 				Path:   "/oauth/authorize", | ||||||
| 			}, | 			}, | ||||||
| 			RedeemURL: &url.URL{ | 			RedeemURL: &url.URL{ | ||||||
| 				Scheme: "http", | 				Scheme: "http", | ||||||
| 				Host:   provider_url.Host, | 				Host:   providerURL.Host, | ||||||
| 				Path:   "/oauth/token", | 				Path:   "/oauth/token", | ||||||
| 			}, | 			}, | ||||||
| 			ProfileURL: &url.URL{ | 			ProfileURL: &url.URL{ | ||||||
| 				Scheme: "http", | 				Scheme: "http", | ||||||
| 				Host:   provider_url.Host, | 				Host:   providerURL.Host, | ||||||
| 				Path:   "/api/v1/profile", | 				Path:   "/api/v1/profile", | ||||||
| 			}, | 			}, | ||||||
| 			Scope: "profile.email", | 			Scope: "profile.email", | ||||||
| 		}, | 		}, | ||||||
| 		EmailAddress: email_address, | 		EmailAddress: emailAddress, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -132,11 +132,10 @@ func (tp *TestProvider) ValidateSessionState(session *providers.SessionState) bo | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestBasicAuthPassword(t *testing.T) { | func TestBasicAuthPassword(t *testing.T) { | ||||||
| 	provider_server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | 	providerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||||
| 		log.Printf("%#v", r) | 		log.Printf("%#v", r) | ||||||
| 		url := r.URL | 		var payload string | ||||||
| 		payload := "" | 		switch r.URL.Path { | ||||||
| 		switch url.Path { |  | ||||||
| 		case "/oauth/token": | 		case "/oauth/token": | ||||||
| 			payload = `{"access_token": "my_auth_token"}` | 			payload = `{"access_token": "my_auth_token"}` | ||||||
| 		default: | 		default: | ||||||
|  | @ -149,7 +148,7 @@ func TestBasicAuthPassword(t *testing.T) { | ||||||
| 		w.Write([]byte(payload)) | 		w.Write([]byte(payload)) | ||||||
| 	})) | 	})) | ||||||
| 	opts := NewOptions() | 	opts := NewOptions() | ||||||
| 	opts.Upstreams = append(opts.Upstreams, provider_server.URL) | 	opts.Upstreams = append(opts.Upstreams, providerServer.URL) | ||||||
| 	// The CookieSecret must be 32 bytes in order to create the AES
 | 	// The CookieSecret must be 32 bytes in order to create the AES
 | ||||||
| 	// cipher.
 | 	// cipher.
 | ||||||
| 	opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp" | 	opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp" | ||||||
|  | @ -161,13 +160,13 @@ func TestBasicAuthPassword(t *testing.T) { | ||||||
| 	opts.BasicAuthPassword = "This is a secure password" | 	opts.BasicAuthPassword = "This is a secure password" | ||||||
| 	opts.Validate() | 	opts.Validate() | ||||||
| 
 | 
 | ||||||
| 	provider_url, _ := url.Parse(provider_server.URL) | 	providerURL, _ := url.Parse(providerServer.URL) | ||||||
| 	const email_address = "michael.bland@gsa.gov" | 	const emailAddress = "michael.bland@gsa.gov" | ||||||
| 	const user_name = "michael.bland" | 	const username = "michael.bland" | ||||||
| 
 | 
 | ||||||
| 	opts.provider = NewTestProvider(provider_url, email_address) | 	opts.provider = NewTestProvider(providerURL, emailAddress) | ||||||
| 	proxy := NewOAuthProxy(opts, func(email string) bool { | 	proxy := NewOAuthProxy(opts, func(email string) bool { | ||||||
| 		return email == email_address | 		return email == emailAddress | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	rw := httptest.NewRecorder() | 	rw := httptest.NewRecorder() | ||||||
|  | @ -182,10 +181,10 @@ func TestBasicAuthPassword(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| 	cookieName := proxy.CookieName | 	cookieName := proxy.CookieName | ||||||
| 	var value string | 	var value string | ||||||
| 	key_prefix := cookieName + "=" | 	keyPrefix := cookieName + "=" | ||||||
| 
 | 
 | ||||||
| 	for _, field := range strings.Split(cookie, "; ") { | 	for _, field := range strings.Split(cookie, "; ") { | ||||||
| 		value = strings.TrimPrefix(field, key_prefix) | 		value = strings.TrimPrefix(field, keyPrefix) | ||||||
| 		if value != field { | 		if value != field { | ||||||
| 			break | 			break | ||||||
| 		} else { | 		} else { | ||||||
|  | @ -206,15 +205,15 @@ func TestBasicAuthPassword(t *testing.T) { | ||||||
| 	rw = httptest.NewRecorder() | 	rw = httptest.NewRecorder() | ||||||
| 	proxy.ServeHTTP(rw, req) | 	proxy.ServeHTTP(rw, req) | ||||||
| 
 | 
 | ||||||
| 	expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(user_name+":"+opts.BasicAuthPassword)) | 	expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+opts.BasicAuthPassword)) | ||||||
| 	assert.Equal(t, expectedHeader, rw.Body.String()) | 	assert.Equal(t, expectedHeader, rw.Body.String()) | ||||||
| 	provider_server.Close() | 	providerServer.Close() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type PassAccessTokenTest struct { | type PassAccessTokenTest struct { | ||||||
| 	provider_server *httptest.Server | 	providerServer *httptest.Server | ||||||
| 	proxy           *OAuthProxy | 	proxy          *OAuthProxy | ||||||
| 	opts            *Options | 	opts           *Options | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type PassAccessTokenTestOptions struct { | type PassAccessTokenTestOptions struct { | ||||||
|  | @ -224,12 +223,11 @@ type PassAccessTokenTestOptions struct { | ||||||
| func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTest { | func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTest { | ||||||
| 	t := &PassAccessTokenTest{} | 	t := &PassAccessTokenTest{} | ||||||
| 
 | 
 | ||||||
| 	t.provider_server = httptest.NewServer( | 	t.providerServer = httptest.NewServer( | ||||||
| 		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | 		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||||
| 			log.Printf("%#v", r) | 			log.Printf("%#v", r) | ||||||
| 			url := r.URL | 			var payload string | ||||||
| 			payload := "" | 			switch r.URL.Path { | ||||||
| 			switch url.Path { |  | ||||||
| 			case "/oauth/token": | 			case "/oauth/token": | ||||||
| 				payload = `{"access_token": "my_auth_token"}` | 				payload = `{"access_token": "my_auth_token"}` | ||||||
| 			default: | 			default: | ||||||
|  | @ -243,7 +241,7 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTes | ||||||
| 		})) | 		})) | ||||||
| 
 | 
 | ||||||
| 	t.opts = NewOptions() | 	t.opts = NewOptions() | ||||||
| 	t.opts.Upstreams = append(t.opts.Upstreams, t.provider_server.URL) | 	t.opts.Upstreams = append(t.opts.Upstreams, t.providerServer.URL) | ||||||
| 	// The CookieSecret must be 32 bytes in order to create the AES
 | 	// The CookieSecret must be 32 bytes in order to create the AES
 | ||||||
| 	// cipher.
 | 	// cipher.
 | ||||||
| 	t.opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp" | 	t.opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp" | ||||||
|  | @ -253,21 +251,21 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTes | ||||||
| 	t.opts.PassAccessToken = opts.PassAccessToken | 	t.opts.PassAccessToken = opts.PassAccessToken | ||||||
| 	t.opts.Validate() | 	t.opts.Validate() | ||||||
| 
 | 
 | ||||||
| 	provider_url, _ := url.Parse(t.provider_server.URL) | 	providerURL, _ := url.Parse(t.providerServer.URL) | ||||||
| 	const email_address = "michael.bland@gsa.gov" | 	const emailAddress = "michael.bland@gsa.gov" | ||||||
| 
 | 
 | ||||||
| 	t.opts.provider = NewTestProvider(provider_url, email_address) | 	t.opts.provider = NewTestProvider(providerURL, emailAddress) | ||||||
| 	t.proxy = NewOAuthProxy(t.opts, func(email string) bool { | 	t.proxy = NewOAuthProxy(t.opts, func(email string) bool { | ||||||
| 		return email == email_address | 		return email == emailAddress | ||||||
| 	}) | 	}) | ||||||
| 	return t | 	return t | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (pat_test *PassAccessTokenTest) Close() { | func (patTest *PassAccessTokenTest) Close() { | ||||||
| 	pat_test.provider_server.Close() | 	patTest.providerServer.Close() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (pat_test *PassAccessTokenTest) getCallbackEndpoint() (http_code int, | func (patTest *PassAccessTokenTest) getCallbackEndpoint() (httpCode int, | ||||||
| 	cookie string) { | 	cookie string) { | ||||||
| 	rw := httptest.NewRecorder() | 	rw := httptest.NewRecorder() | ||||||
| 	req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:", | 	req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:", | ||||||
|  | @ -275,18 +273,18 @@ func (pat_test *PassAccessTokenTest) getCallbackEndpoint() (http_code int, | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return 0, "" | 		return 0, "" | ||||||
| 	} | 	} | ||||||
| 	req.AddCookie(pat_test.proxy.MakeCSRFCookie(req, "nonce", time.Hour, time.Now())) | 	req.AddCookie(patTest.proxy.MakeCSRFCookie(req, "nonce", time.Hour, time.Now())) | ||||||
| 	pat_test.proxy.ServeHTTP(rw, req) | 	patTest.proxy.ServeHTTP(rw, req) | ||||||
| 	return rw.Code, rw.HeaderMap["Set-Cookie"][1] | 	return rw.Code, rw.HeaderMap["Set-Cookie"][1] | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (pat_test *PassAccessTokenTest) getRootEndpoint(cookie string) (http_code int, access_token string) { | func (patTest *PassAccessTokenTest) getRootEndpoint(cookie string) (httpCode int, accessToken string) { | ||||||
| 	cookieName := pat_test.proxy.CookieName | 	cookieName := patTest.proxy.CookieName | ||||||
| 	var value string | 	var value string | ||||||
| 	key_prefix := cookieName + "=" | 	keyPrefix := cookieName + "=" | ||||||
| 
 | 
 | ||||||
| 	for _, field := range strings.Split(cookie, "; ") { | 	for _, field := range strings.Split(cookie, "; ") { | ||||||
| 		value = strings.TrimPrefix(field, key_prefix) | 		value = strings.TrimPrefix(field, keyPrefix) | ||||||
| 		if value != field { | 		if value != field { | ||||||
| 			break | 			break | ||||||
| 		} else { | 		} else { | ||||||
|  | @ -310,18 +308,18 @@ func (pat_test *PassAccessTokenTest) getRootEndpoint(cookie string) (http_code i | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	rw := httptest.NewRecorder() | 	rw := httptest.NewRecorder() | ||||||
| 	pat_test.proxy.ServeHTTP(rw, req) | 	patTest.proxy.ServeHTTP(rw, req) | ||||||
| 	return rw.Code, rw.Body.String() | 	return rw.Code, rw.Body.String() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestForwardAccessTokenUpstream(t *testing.T) { | func TestForwardAccessTokenUpstream(t *testing.T) { | ||||||
| 	pat_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{ | 	patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{ | ||||||
| 		PassAccessToken: true, | 		PassAccessToken: true, | ||||||
| 	}) | 	}) | ||||||
| 	defer pat_test.Close() | 	defer patTest.Close() | ||||||
| 
 | 
 | ||||||
| 	// A successful validation will redirect and set the auth cookie.
 | 	// A successful validation will redirect and set the auth cookie.
 | ||||||
| 	code, cookie := pat_test.getCallbackEndpoint() | 	code, cookie := patTest.getCallbackEndpoint() | ||||||
| 	if code != 302 { | 	if code != 302 { | ||||||
| 		t.Fatalf("expected 302; got %d", code) | 		t.Fatalf("expected 302; got %d", code) | ||||||
| 	} | 	} | ||||||
|  | @ -330,7 +328,7 @@ func TestForwardAccessTokenUpstream(t *testing.T) { | ||||||
| 	// Now we make a regular request; the access_token from the cookie is
 | 	// Now we make a regular request; the access_token from the cookie is
 | ||||||
| 	// forwarded as the "X-Forwarded-Access-Token" header. The token is
 | 	// forwarded as the "X-Forwarded-Access-Token" header. The token is
 | ||||||
| 	// read by the test provider server and written in the response body.
 | 	// read by the test provider server and written in the response body.
 | ||||||
| 	code, payload := pat_test.getRootEndpoint(cookie) | 	code, payload := patTest.getRootEndpoint(cookie) | ||||||
| 	if code != 200 { | 	if code != 200 { | ||||||
| 		t.Fatalf("expected 200; got %d", code) | 		t.Fatalf("expected 200; got %d", code) | ||||||
| 	} | 	} | ||||||
|  | @ -338,13 +336,13 @@ func TestForwardAccessTokenUpstream(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestDoNotForwardAccessTokenUpstream(t *testing.T) { | func TestDoNotForwardAccessTokenUpstream(t *testing.T) { | ||||||
| 	pat_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{ | 	patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{ | ||||||
| 		PassAccessToken: false, | 		PassAccessToken: false, | ||||||
| 	}) | 	}) | ||||||
| 	defer pat_test.Close() | 	defer patTest.Close() | ||||||
| 
 | 
 | ||||||
| 	// A successful validation will redirect and set the auth cookie.
 | 	// A successful validation will redirect and set the auth cookie.
 | ||||||
| 	code, cookie := pat_test.getCallbackEndpoint() | 	code, cookie := patTest.getCallbackEndpoint() | ||||||
| 	if code != 302 { | 	if code != 302 { | ||||||
| 		t.Fatalf("expected 302; got %d", code) | 		t.Fatalf("expected 302; got %d", code) | ||||||
| 	} | 	} | ||||||
|  | @ -352,7 +350,7 @@ func TestDoNotForwardAccessTokenUpstream(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| 	// Now we make a regular request, but the access token header should
 | 	// Now we make a regular request, but the access token header should
 | ||||||
| 	// not be present.
 | 	// not be present.
 | ||||||
| 	code, payload := pat_test.getRootEndpoint(cookie) | 	code, payload := patTest.getRootEndpoint(cookie) | ||||||
| 	if code != 200 { | 	if code != 200 { | ||||||
| 		t.Fatalf("expected 200; got %d", code) | 		t.Fatalf("expected 200; got %d", code) | ||||||
| 	} | 	} | ||||||
|  | @ -360,49 +358,49 @@ func TestDoNotForwardAccessTokenUpstream(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type SignInPageTest struct { | type SignInPageTest struct { | ||||||
| 	opts                    *Options | 	opts                 *Options | ||||||
| 	proxy                   *OAuthProxy | 	proxy                *OAuthProxy | ||||||
| 	sign_in_regexp          *regexp.Regexp | 	signInRegexp         *regexp.Regexp | ||||||
| 	sign_in_provider_regexp *regexp.Regexp | 	signInProviderRegexp *regexp.Regexp | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| const signInRedirectPattern = `<input type="hidden" name="rd" value="(.*)">` | const signInRedirectPattern = `<input type="hidden" name="rd" value="(.*)">` | ||||||
| const signInSkipProvider = `>Found<` | const signInSkipProvider = `>Found<` | ||||||
| 
 | 
 | ||||||
| func NewSignInPageTest(skipProvider bool) *SignInPageTest { | func NewSignInPageTest(skipProvider bool) *SignInPageTest { | ||||||
| 	var sip_test SignInPageTest | 	var sipTest SignInPageTest | ||||||
| 
 | 
 | ||||||
| 	sip_test.opts = NewOptions() | 	sipTest.opts = NewOptions() | ||||||
| 	sip_test.opts.CookieSecret = "foobar" | 	sipTest.opts.CookieSecret = "foobar" | ||||||
| 	sip_test.opts.ClientID = "bazquux" | 	sipTest.opts.ClientID = "bazquux" | ||||||
| 	sip_test.opts.ClientSecret = "xyzzyplugh" | 	sipTest.opts.ClientSecret = "xyzzyplugh" | ||||||
| 	sip_test.opts.SkipProviderButton = skipProvider | 	sipTest.opts.SkipProviderButton = skipProvider | ||||||
| 	sip_test.opts.Validate() | 	sipTest.opts.Validate() | ||||||
| 
 | 
 | ||||||
| 	sip_test.proxy = NewOAuthProxy(sip_test.opts, func(email string) bool { | 	sipTest.proxy = NewOAuthProxy(sipTest.opts, func(email string) bool { | ||||||
| 		return true | 		return true | ||||||
| 	}) | 	}) | ||||||
| 	sip_test.sign_in_regexp = regexp.MustCompile(signInRedirectPattern) | 	sipTest.signInRegexp = regexp.MustCompile(signInRedirectPattern) | ||||||
| 	sip_test.sign_in_provider_regexp = regexp.MustCompile(signInSkipProvider) | 	sipTest.signInProviderRegexp = regexp.MustCompile(signInSkipProvider) | ||||||
| 
 | 
 | ||||||
| 	return &sip_test | 	return &sipTest | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (sip_test *SignInPageTest) GetEndpoint(endpoint string) (int, string) { | func (sipTest *SignInPageTest) GetEndpoint(endpoint string) (int, string) { | ||||||
| 	rw := httptest.NewRecorder() | 	rw := httptest.NewRecorder() | ||||||
| 	req, _ := http.NewRequest("GET", endpoint, strings.NewReader("")) | 	req, _ := http.NewRequest("GET", endpoint, strings.NewReader("")) | ||||||
| 	sip_test.proxy.ServeHTTP(rw, req) | 	sipTest.proxy.ServeHTTP(rw, req) | ||||||
| 	return rw.Code, rw.Body.String() | 	return rw.Code, rw.Body.String() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestSignInPageIncludesTargetRedirect(t *testing.T) { | func TestSignInPageIncludesTargetRedirect(t *testing.T) { | ||||||
| 	sip_test := NewSignInPageTest(false) | 	sipTest := NewSignInPageTest(false) | ||||||
| 	const endpoint = "/some/random/endpoint" | 	const endpoint = "/some/random/endpoint" | ||||||
| 
 | 
 | ||||||
| 	code, body := sip_test.GetEndpoint(endpoint) | 	code, body := sipTest.GetEndpoint(endpoint) | ||||||
| 	assert.Equal(t, 403, code) | 	assert.Equal(t, 403, code) | ||||||
| 
 | 
 | ||||||
| 	match := sip_test.sign_in_regexp.FindStringSubmatch(body) | 	match := sipTest.signInRegexp.FindStringSubmatch(body) | ||||||
| 	if match == nil { | 	if match == nil { | ||||||
| 		t.Fatal("Did not find pattern in body: " + | 		t.Fatal("Did not find pattern in body: " + | ||||||
| 			signInRedirectPattern + "\nBody:\n" + body) | 			signInRedirectPattern + "\nBody:\n" + body) | ||||||
|  | @ -414,11 +412,11 @@ func TestSignInPageIncludesTargetRedirect(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) { | func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) { | ||||||
| 	sip_test := NewSignInPageTest(false) | 	sipTest := NewSignInPageTest(false) | ||||||
| 	code, body := sip_test.GetEndpoint("/oauth2/sign_in") | 	code, body := sipTest.GetEndpoint("/oauth2/sign_in") | ||||||
| 	assert.Equal(t, 200, code) | 	assert.Equal(t, 200, code) | ||||||
| 
 | 
 | ||||||
| 	match := sip_test.sign_in_regexp.FindStringSubmatch(body) | 	match := sipTest.signInRegexp.FindStringSubmatch(body) | ||||||
| 	if match == nil { | 	if match == nil { | ||||||
| 		t.Fatal("Did not find pattern in body: " + | 		t.Fatal("Did not find pattern in body: " + | ||||||
| 			signInRedirectPattern + "\nBody:\n" + body) | 			signInRedirectPattern + "\nBody:\n" + body) | ||||||
|  | @ -429,13 +427,13 @@ func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestSignInPageSkipProvider(t *testing.T) { | func TestSignInPageSkipProvider(t *testing.T) { | ||||||
| 	sip_test := NewSignInPageTest(true) | 	sipTest := NewSignInPageTest(true) | ||||||
| 	const endpoint = "/some/random/endpoint" | 	const endpoint = "/some/random/endpoint" | ||||||
| 
 | 
 | ||||||
| 	code, body := sip_test.GetEndpoint(endpoint) | 	code, body := sipTest.GetEndpoint(endpoint) | ||||||
| 	assert.Equal(t, 302, code) | 	assert.Equal(t, 302, code) | ||||||
| 
 | 
 | ||||||
| 	match := sip_test.sign_in_provider_regexp.FindStringSubmatch(body) | 	match := sipTest.signInProviderRegexp.FindStringSubmatch(body) | ||||||
| 	if match == nil { | 	if match == nil { | ||||||
| 		t.Fatal("Did not find pattern in body: " + | 		t.Fatal("Did not find pattern in body: " + | ||||||
| 			signInSkipProvider + "\nBody:\n" + body) | 			signInSkipProvider + "\nBody:\n" + body) | ||||||
|  | @ -443,13 +441,13 @@ func TestSignInPageSkipProvider(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestSignInPageSkipProviderDirect(t *testing.T) { | func TestSignInPageSkipProviderDirect(t *testing.T) { | ||||||
| 	sip_test := NewSignInPageTest(true) | 	sipTest := NewSignInPageTest(true) | ||||||
| 	const endpoint = "/sign_in" | 	const endpoint = "/sign_in" | ||||||
| 
 | 
 | ||||||
| 	code, body := sip_test.GetEndpoint(endpoint) | 	code, body := sipTest.GetEndpoint(endpoint) | ||||||
| 	assert.Equal(t, 302, code) | 	assert.Equal(t, 302, code) | ||||||
| 
 | 
 | ||||||
| 	match := sip_test.sign_in_provider_regexp.FindStringSubmatch(body) | 	match := sipTest.signInProviderRegexp.FindStringSubmatch(body) | ||||||
| 	if match == nil { | 	if match == nil { | ||||||
| 		t.Fatal("Did not find pattern in body: " + | 		t.Fatal("Did not find pattern in body: " + | ||||||
| 			signInSkipProvider + "\nBody:\n" + body) | 			signInSkipProvider + "\nBody:\n" + body) | ||||||
|  | @ -457,50 +455,50 @@ func TestSignInPageSkipProviderDirect(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type ProcessCookieTest struct { | type ProcessCookieTest struct { | ||||||
| 	opts          *Options | 	opts         *Options | ||||||
| 	proxy         *OAuthProxy | 	proxy        *OAuthProxy | ||||||
| 	rw            *httptest.ResponseRecorder | 	rw           *httptest.ResponseRecorder | ||||||
| 	req           *http.Request | 	req          *http.Request | ||||||
| 	provider      TestProvider | 	provider     TestProvider | ||||||
| 	response_code int | 	responseCode int | ||||||
| 	validate_user bool | 	validateUser bool | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type ProcessCookieTestOpts struct { | type ProcessCookieTestOpts struct { | ||||||
| 	provider_validate_cookie_response bool | 	providerValidateCookieResponse bool | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func NewProcessCookieTest(opts ProcessCookieTestOpts) *ProcessCookieTest { | func NewProcessCookieTest(opts ProcessCookieTestOpts) *ProcessCookieTest { | ||||||
| 	var pc_test ProcessCookieTest | 	var pcTest ProcessCookieTest | ||||||
| 
 | 
 | ||||||
| 	pc_test.opts = NewOptions() | 	pcTest.opts = NewOptions() | ||||||
| 	pc_test.opts.ClientID = "bazquux" | 	pcTest.opts.ClientID = "bazquux" | ||||||
| 	pc_test.opts.ClientSecret = "xyzzyplugh" | 	pcTest.opts.ClientSecret = "xyzzyplugh" | ||||||
| 	pc_test.opts.CookieSecret = "0123456789abcdefabcd" | 	pcTest.opts.CookieSecret = "0123456789abcdefabcd" | ||||||
| 	// First, set the CookieRefresh option so proxy.AesCipher is created,
 | 	// First, set the CookieRefresh option so proxy.AesCipher is created,
 | ||||||
| 	// needed to encrypt the access_token.
 | 	// needed to encrypt the access_token.
 | ||||||
| 	pc_test.opts.CookieRefresh = time.Hour | 	pcTest.opts.CookieRefresh = time.Hour | ||||||
| 	pc_test.opts.Validate() | 	pcTest.opts.Validate() | ||||||
| 
 | 
 | ||||||
| 	pc_test.proxy = NewOAuthProxy(pc_test.opts, func(email string) bool { | 	pcTest.proxy = NewOAuthProxy(pcTest.opts, func(email string) bool { | ||||||
| 		return pc_test.validate_user | 		return pcTest.validateUser | ||||||
| 	}) | 	}) | ||||||
| 	pc_test.proxy.provider = &TestProvider{ | 	pcTest.proxy.provider = &TestProvider{ | ||||||
| 		ValidToken: opts.provider_validate_cookie_response, | 		ValidToken: opts.providerValidateCookieResponse, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Now, zero-out proxy.CookieRefresh for the cases that don't involve
 | 	// Now, zero-out proxy.CookieRefresh for the cases that don't involve
 | ||||||
| 	// access_token validation.
 | 	// access_token validation.
 | ||||||
| 	pc_test.proxy.CookieRefresh = time.Duration(0) | 	pcTest.proxy.CookieRefresh = time.Duration(0) | ||||||
| 	pc_test.rw = httptest.NewRecorder() | 	pcTest.rw = httptest.NewRecorder() | ||||||
| 	pc_test.req, _ = http.NewRequest("GET", "/", strings.NewReader("")) | 	pcTest.req, _ = http.NewRequest("GET", "/", strings.NewReader("")) | ||||||
| 	pc_test.validate_user = true | 	pcTest.validateUser = true | ||||||
| 	return &pc_test | 	return &pcTest | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func NewProcessCookieTestWithDefaults() *ProcessCookieTest { | func NewProcessCookieTestWithDefaults() *ProcessCookieTest { | ||||||
| 	return NewProcessCookieTest(ProcessCookieTestOpts{ | 	return NewProcessCookieTest(ProcessCookieTestOpts{ | ||||||
| 		provider_validate_cookie_response: true, | 		providerValidateCookieResponse: true, | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -522,12 +520,12 @@ func (p *ProcessCookieTest) LoadCookiedSession() (*providers.SessionState, time. | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestLoadCookiedSession(t *testing.T) { | func TestLoadCookiedSession(t *testing.T) { | ||||||
| 	pc_test := NewProcessCookieTestWithDefaults() | 	pcTest := NewProcessCookieTestWithDefaults() | ||||||
| 
 | 
 | ||||||
| 	startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | 	startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | ||||||
| 	pc_test.SaveSession(startSession, time.Now()) | 	pcTest.SaveSession(startSession, time.Now()) | ||||||
| 
 | 
 | ||||||
| 	session, _, err := pc_test.LoadCookiedSession() | 	session, _, err := pcTest.LoadCookiedSession() | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.Equal(t, startSession.Email, session.Email) | 	assert.Equal(t, startSession.Email, session.Email) | ||||||
| 	assert.Equal(t, "michael.bland", session.User) | 	assert.Equal(t, "michael.bland", session.User) | ||||||
|  | @ -535,9 +533,9 @@ func TestLoadCookiedSession(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestProcessCookieNoCookieError(t *testing.T) { | func TestProcessCookieNoCookieError(t *testing.T) { | ||||||
| 	pc_test := NewProcessCookieTestWithDefaults() | 	pcTest := NewProcessCookieTestWithDefaults() | ||||||
| 
 | 
 | ||||||
| 	session, _, err := pc_test.LoadCookiedSession() | 	session, _, err := pcTest.LoadCookiedSession() | ||||||
| 	assert.Equal(t, "Cookie \"_oauth2_proxy\" not present", err.Error()) | 	assert.Equal(t, "Cookie \"_oauth2_proxy\" not present", err.Error()) | ||||||
| 	if session != nil { | 	if session != nil { | ||||||
| 		t.Errorf("expected nil session. got %#v", session) | 		t.Errorf("expected nil session. got %#v", session) | ||||||
|  | @ -545,14 +543,14 @@ func TestProcessCookieNoCookieError(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestProcessCookieRefreshNotSet(t *testing.T) { | func TestProcessCookieRefreshNotSet(t *testing.T) { | ||||||
| 	pc_test := NewProcessCookieTestWithDefaults() | 	pcTest := NewProcessCookieTestWithDefaults() | ||||||
| 	pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour | 	pcTest.proxy.CookieExpire = time.Duration(23) * time.Hour | ||||||
| 	reference := time.Now().Add(time.Duration(-2) * time.Hour) | 	reference := time.Now().Add(time.Duration(-2) * time.Hour) | ||||||
| 
 | 
 | ||||||
| 	startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | 	startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | ||||||
| 	pc_test.SaveSession(startSession, reference) | 	pcTest.SaveSession(startSession, reference) | ||||||
| 
 | 
 | ||||||
| 	session, age, err := pc_test.LoadCookiedSession() | 	session, age, err := pcTest.LoadCookiedSession() | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	if age < time.Duration(-2)*time.Hour { | 	if age < time.Duration(-2)*time.Hour { | ||||||
| 		t.Errorf("cookie too young %v", age) | 		t.Errorf("cookie too young %v", age) | ||||||
|  | @ -561,13 +559,13 @@ func TestProcessCookieRefreshNotSet(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestProcessCookieFailIfCookieExpired(t *testing.T) { | func TestProcessCookieFailIfCookieExpired(t *testing.T) { | ||||||
| 	pc_test := NewProcessCookieTestWithDefaults() | 	pcTest := NewProcessCookieTestWithDefaults() | ||||||
| 	pc_test.proxy.CookieExpire = time.Duration(24) * time.Hour | 	pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour | ||||||
| 	reference := time.Now().Add(time.Duration(25) * time.Hour * -1) | 	reference := time.Now().Add(time.Duration(25) * time.Hour * -1) | ||||||
| 	startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | 	startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | ||||||
| 	pc_test.SaveSession(startSession, reference) | 	pcTest.SaveSession(startSession, reference) | ||||||
| 
 | 
 | ||||||
| 	session, _, err := pc_test.LoadCookiedSession() | 	session, _, err := pcTest.LoadCookiedSession() | ||||||
| 	assert.NotEqual(t, nil, err) | 	assert.NotEqual(t, nil, err) | ||||||
| 	if session != nil { | 	if session != nil { | ||||||
| 		t.Errorf("expected nil session %#v", session) | 		t.Errorf("expected nil session %#v", session) | ||||||
|  | @ -575,14 +573,14 @@ func TestProcessCookieFailIfCookieExpired(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) { | func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) { | ||||||
| 	pc_test := NewProcessCookieTestWithDefaults() | 	pcTest := NewProcessCookieTestWithDefaults() | ||||||
| 	pc_test.proxy.CookieExpire = time.Duration(24) * time.Hour | 	pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour | ||||||
| 	reference := time.Now().Add(time.Duration(25) * time.Hour * -1) | 	reference := time.Now().Add(time.Duration(25) * time.Hour * -1) | ||||||
| 	startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | 	startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | ||||||
| 	pc_test.SaveSession(startSession, reference) | 	pcTest.SaveSession(startSession, reference) | ||||||
| 
 | 
 | ||||||
| 	pc_test.proxy.CookieRefresh = time.Hour | 	pcTest.proxy.CookieRefresh = time.Hour | ||||||
| 	session, _, err := pc_test.LoadCookiedSession() | 	session, _, err := pcTest.LoadCookiedSession() | ||||||
| 	assert.NotEqual(t, nil, err) | 	assert.NotEqual(t, nil, err) | ||||||
| 	if session != nil { | 	if session != nil { | ||||||
| 		t.Errorf("expected nil session %#v", session) | 		t.Errorf("expected nil session %#v", session) | ||||||
|  | @ -590,10 +588,10 @@ func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func NewAuthOnlyEndpointTest() *ProcessCookieTest { | func NewAuthOnlyEndpointTest() *ProcessCookieTest { | ||||||
| 	pc_test := NewProcessCookieTestWithDefaults() | 	pcTest := NewProcessCookieTestWithDefaults() | ||||||
| 	pc_test.req, _ = http.NewRequest("GET", | 	pcTest.req, _ = http.NewRequest("GET", | ||||||
| 		pc_test.opts.ProxyPrefix+"/auth", nil) | 		pcTest.opts.ProxyPrefix+"/auth", nil) | ||||||
| 	return pc_test | 	return pcTest | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestAuthOnlyEndpointAccepted(t *testing.T) { | func TestAuthOnlyEndpointAccepted(t *testing.T) { | ||||||
|  | @ -636,7 +634,7 @@ func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) { | ||||||
| 	startSession := &providers.SessionState{ | 	startSession := &providers.SessionState{ | ||||||
| 		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | 		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | ||||||
| 	test.SaveSession(startSession, time.Now()) | 	test.SaveSession(startSession, time.Now()) | ||||||
| 	test.validate_user = false | 	test.validateUser = false | ||||||
| 
 | 
 | ||||||
| 	test.proxy.ServeHTTP(test.rw, test.req) | 	test.proxy.ServeHTTP(test.rw, test.req) | ||||||
| 	assert.Equal(t, http.StatusUnauthorized, test.rw.Code) | 	assert.Equal(t, http.StatusUnauthorized, test.rw.Code) | ||||||
|  | @ -645,33 +643,33 @@ func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { | func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { | ||||||
| 	var pc_test ProcessCookieTest | 	var pcTest ProcessCookieTest | ||||||
| 
 | 
 | ||||||
| 	pc_test.opts = NewOptions() | 	pcTest.opts = NewOptions() | ||||||
| 	pc_test.opts.SetXAuthRequest = true | 	pcTest.opts.SetXAuthRequest = true | ||||||
| 	pc_test.opts.Validate() | 	pcTest.opts.Validate() | ||||||
| 
 | 
 | ||||||
| 	pc_test.proxy = NewOAuthProxy(pc_test.opts, func(email string) bool { | 	pcTest.proxy = NewOAuthProxy(pcTest.opts, func(email string) bool { | ||||||
| 		return pc_test.validate_user | 		return pcTest.validateUser | ||||||
| 	}) | 	}) | ||||||
| 	pc_test.proxy.provider = &TestProvider{ | 	pcTest.proxy.provider = &TestProvider{ | ||||||
| 		ValidToken: true, | 		ValidToken: true, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	pc_test.validate_user = true | 	pcTest.validateUser = true | ||||||
| 
 | 
 | ||||||
| 	pc_test.rw = httptest.NewRecorder() | 	pcTest.rw = httptest.NewRecorder() | ||||||
| 	pc_test.req, _ = http.NewRequest("GET", | 	pcTest.req, _ = http.NewRequest("GET", | ||||||
| 		pc_test.opts.ProxyPrefix+"/auth", nil) | 		pcTest.opts.ProxyPrefix+"/auth", nil) | ||||||
| 
 | 
 | ||||||
| 	startSession := &providers.SessionState{ | 	startSession := &providers.SessionState{ | ||||||
| 		User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"} | 		User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"} | ||||||
| 	pc_test.SaveSession(startSession, time.Now()) | 	pcTest.SaveSession(startSession, time.Now()) | ||||||
| 
 | 
 | ||||||
| 	pc_test.proxy.ServeHTTP(pc_test.rw, pc_test.req) | 	pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req) | ||||||
| 	assert.Equal(t, http.StatusAccepted, pc_test.rw.Code) | 	assert.Equal(t, http.StatusAccepted, pcTest.rw.Code) | ||||||
| 	assert.Equal(t, "oauth_user", pc_test.rw.HeaderMap["X-Auth-Request-User"][0]) | 	assert.Equal(t, "oauth_user", pcTest.rw.HeaderMap["X-Auth-Request-User"][0]) | ||||||
| 	assert.Equal(t, "oauth_user@example.com", pc_test.rw.HeaderMap["X-Auth-Request-Email"][0]) | 	assert.Equal(t, "oauth_user@example.com", pcTest.rw.HeaderMap["X-Auth-Request-Email"][0]) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestAuthSkippedForPreflightRequests(t *testing.T) { | func TestAuthSkippedForPreflightRequests(t *testing.T) { | ||||||
|  | @ -689,8 +687,8 @@ func TestAuthSkippedForPreflightRequests(t *testing.T) { | ||||||
| 	opts.SkipAuthPreflight = true | 	opts.SkipAuthPreflight = true | ||||||
| 	opts.Validate() | 	opts.Validate() | ||||||
| 
 | 
 | ||||||
| 	upstream_url, _ := url.Parse(upstream.URL) | 	upstreamURL, _ := url.Parse(upstream.URL) | ||||||
| 	opts.provider = NewTestProvider(upstream_url, "") | 	opts.provider = NewTestProvider(upstreamURL, "") | ||||||
| 
 | 
 | ||||||
| 	proxy := NewOAuthProxy(opts, func(string) bool { return false }) | 	proxy := NewOAuthProxy(opts, func(string) bool { return false }) | ||||||
| 	rw := httptest.NewRecorder() | 	rw := httptest.NewRecorder() | ||||||
|  | @ -723,7 +721,7 @@ func (v *SignatureAuthenticator) Authenticate(w http.ResponseWriter, r *http.Req | ||||||
| type SignatureTest struct { | type SignatureTest struct { | ||||||
| 	opts          *Options | 	opts          *Options | ||||||
| 	upstream      *httptest.Server | 	upstream      *httptest.Server | ||||||
| 	upstream_host string | 	upstreamHost  string | ||||||
| 	provider      *httptest.Server | 	provider      *httptest.Server | ||||||
| 	header        http.Header | 	header        http.Header | ||||||
| 	rw            *httptest.ResponseRecorder | 	rw            *httptest.ResponseRecorder | ||||||
|  | @ -740,20 +738,20 @@ func NewSignatureTest() *SignatureTest { | ||||||
| 	authenticator := &SignatureAuthenticator{} | 	authenticator := &SignatureAuthenticator{} | ||||||
| 	upstream := httptest.NewServer( | 	upstream := httptest.NewServer( | ||||||
| 		http.HandlerFunc(authenticator.Authenticate)) | 		http.HandlerFunc(authenticator.Authenticate)) | ||||||
| 	upstream_url, _ := url.Parse(upstream.URL) | 	upstreamURL, _ := url.Parse(upstream.URL) | ||||||
| 	opts.Upstreams = append(opts.Upstreams, upstream.URL) | 	opts.Upstreams = append(opts.Upstreams, upstream.URL) | ||||||
| 
 | 
 | ||||||
| 	providerHandler := func(w http.ResponseWriter, r *http.Request) { | 	providerHandler := func(w http.ResponseWriter, r *http.Request) { | ||||||
| 		w.Write([]byte(`{"access_token": "my_auth_token"}`)) | 		w.Write([]byte(`{"access_token": "my_auth_token"}`)) | ||||||
| 	} | 	} | ||||||
| 	provider := httptest.NewServer(http.HandlerFunc(providerHandler)) | 	provider := httptest.NewServer(http.HandlerFunc(providerHandler)) | ||||||
| 	provider_url, _ := url.Parse(provider.URL) | 	providerURL, _ := url.Parse(provider.URL) | ||||||
| 	opts.provider = NewTestProvider(provider_url, "mbland@acm.org") | 	opts.provider = NewTestProvider(providerURL, "mbland@acm.org") | ||||||
| 
 | 
 | ||||||
| 	return &SignatureTest{ | 	return &SignatureTest{ | ||||||
| 		opts, | 		opts, | ||||||
| 		upstream, | 		upstream, | ||||||
| 		upstream_url.Host, | 		upstreamURL.Host, | ||||||
| 		provider, | 		provider, | ||||||
| 		make(http.Header), | 		make(http.Header), | ||||||
| 		httptest.NewRecorder(), | 		httptest.NewRecorder(), | ||||||
|  |  | ||||||
							
								
								
									
										40
									
								
								options.go
								
								
								
								
							
							
						
						
									
										40
									
								
								options.go
								
								
								
								
							|  | @ -13,16 +13,17 @@ import ( | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/bitly/oauth2_proxy/providers" |  | ||||||
| 	oidc "github.com/coreos/go-oidc" | 	oidc "github.com/coreos/go-oidc" | ||||||
| 	"github.com/mbland/hmacauth" | 	"github.com/mbland/hmacauth" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/providers" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // Configuration Options that can be set by Command Line Flag, or Config File
 | // Options holds Configuration Options that can be set by Command Line Flag,
 | ||||||
|  | // or Config File
 | ||||||
| type Options struct { | type Options struct { | ||||||
| 	ProxyPrefix  string `flag:"proxy-prefix" cfg:"proxy-prefix"` | 	ProxyPrefix  string `flag:"proxy-prefix" cfg:"proxy-prefix"` | ||||||
| 	HttpAddress  string `flag:"http-address" cfg:"http_address"` | 	HTTPAddress  string `flag:"http-address" cfg:"http_address"` | ||||||
| 	HttpsAddress string `flag:"https-address" cfg:"https_address"` | 	HTTPSAddress string `flag:"https-address" cfg:"https_address"` | ||||||
| 	RedirectURL  string `flag:"redirect-url" cfg:"redirect_url"` | 	RedirectURL  string `flag:"redirect-url" cfg:"redirect_url"` | ||||||
| 	ClientID     string `flag:"client-id" cfg:"client_id" env:"OAUTH2_PROXY_CLIENT_ID"` | 	ClientID     string `flag:"client-id" cfg:"client_id" env:"OAUTH2_PROXY_CLIENT_ID"` | ||||||
| 	ClientSecret string `flag:"client-secret" cfg:"client_secret" env:"OAUTH2_PROXY_CLIENT_SECRET"` | 	ClientSecret string `flag:"client-secret" cfg:"client_secret" env:"OAUTH2_PROXY_CLIENT_SECRET"` | ||||||
|  | @ -48,7 +49,7 @@ type Options struct { | ||||||
| 	CookieExpire   time.Duration `flag:"cookie-expire" cfg:"cookie_expire" env:"OAUTH2_PROXY_COOKIE_EXPIRE"` | 	CookieExpire   time.Duration `flag:"cookie-expire" cfg:"cookie_expire" env:"OAUTH2_PROXY_COOKIE_EXPIRE"` | ||||||
| 	CookieRefresh  time.Duration `flag:"cookie-refresh" cfg:"cookie_refresh" env:"OAUTH2_PROXY_COOKIE_REFRESH"` | 	CookieRefresh  time.Duration `flag:"cookie-refresh" cfg:"cookie_refresh" env:"OAUTH2_PROXY_COOKIE_REFRESH"` | ||||||
| 	CookieSecure   bool          `flag:"cookie-secure" cfg:"cookie_secure"` | 	CookieSecure   bool          `flag:"cookie-secure" cfg:"cookie_secure"` | ||||||
| 	CookieHttpOnly bool          `flag:"cookie-httponly" cfg:"cookie_httponly"` | 	CookieHTTPOnly bool          `flag:"cookie-httponly" cfg:"cookie_httponly"` | ||||||
| 
 | 
 | ||||||
| 	Upstreams             []string `flag:"upstream" cfg:"upstreams"` | 	Upstreams             []string `flag:"upstream" cfg:"upstreams"` | ||||||
| 	SkipAuthRegex         []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"` | 	SkipAuthRegex         []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"` | ||||||
|  | @ -88,20 +89,22 @@ type Options struct { | ||||||
| 	oidcVerifier  *oidc.IDTokenVerifier | 	oidcVerifier  *oidc.IDTokenVerifier | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // SignatureData holds hmacauth signature hash and key
 | ||||||
| type SignatureData struct { | type SignatureData struct { | ||||||
| 	hash crypto.Hash | 	hash crypto.Hash | ||||||
| 	key  string | 	key  string | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // NewOptions constructs a new Options with defaulted values
 | ||||||
| func NewOptions() *Options { | func NewOptions() *Options { | ||||||
| 	return &Options{ | 	return &Options{ | ||||||
| 		ProxyPrefix:          "/oauth2", | 		ProxyPrefix:          "/oauth2", | ||||||
| 		HttpAddress:          "127.0.0.1:4180", | 		HTTPAddress:          "127.0.0.1:4180", | ||||||
| 		HttpsAddress:         ":443", | 		HTTPSAddress:         ":443", | ||||||
| 		DisplayHtpasswdForm:  true, | 		DisplayHtpasswdForm:  true, | ||||||
| 		CookieName:           "_oauth2_proxy", | 		CookieName:           "_oauth2_proxy", | ||||||
| 		CookieSecure:         true, | 		CookieSecure:         true, | ||||||
| 		CookieHttpOnly:       true, | 		CookieHTTPOnly:       true, | ||||||
| 		CookieExpire:         time.Duration(168) * time.Hour, | 		CookieExpire:         time.Duration(168) * time.Hour, | ||||||
| 		CookieRefresh:        time.Duration(0), | 		CookieRefresh:        time.Duration(0), | ||||||
| 		SetXAuthRequest:      false, | 		SetXAuthRequest:      false, | ||||||
|  | @ -116,15 +119,17 @@ func NewOptions() *Options { | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func parseURL(to_parse string, urltype string, msgs []string) (*url.URL, []string) { | func parseURL(toParse string, urltype string, msgs []string) (*url.URL, []string) { | ||||||
| 	parsed, err := url.Parse(to_parse) | 	parsed, err := url.Parse(toParse) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, append(msgs, fmt.Sprintf( | 		return nil, append(msgs, fmt.Sprintf( | ||||||
| 			"error parsing %s-url=%q %s", urltype, to_parse, err)) | 			"error parsing %s-url=%q %s", urltype, toParse, err)) | ||||||
| 	} | 	} | ||||||
| 	return parsed, msgs | 	return parsed, msgs | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // Validate checks that required options are set and validates those that they
 | ||||||
|  | // are of the correct format
 | ||||||
| func (o *Options) Validate() error { | func (o *Options) Validate() error { | ||||||
| 	if o.SSLInsecureSkipVerify { | 	if o.SSLInsecureSkipVerify { | ||||||
| 		// TODO: Accept a certificate bundle.
 | 		// TODO: Accept a certificate bundle.
 | ||||||
|  | @ -190,17 +195,17 @@ func (o *Options) Validate() error { | ||||||
| 	msgs = parseProviderInfo(o, msgs) | 	msgs = parseProviderInfo(o, msgs) | ||||||
| 
 | 
 | ||||||
| 	if o.PassAccessToken || (o.CookieRefresh != time.Duration(0)) { | 	if o.PassAccessToken || (o.CookieRefresh != time.Duration(0)) { | ||||||
| 		valid_cookie_secret_size := false | 		validCookieSecretSize := false | ||||||
| 		for _, i := range []int{16, 24, 32} { | 		for _, i := range []int{16, 24, 32} { | ||||||
| 			if len(secretBytes(o.CookieSecret)) == i { | 			if len(secretBytes(o.CookieSecret)) == i { | ||||||
| 				valid_cookie_secret_size = true | 				validCookieSecretSize = true | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 		var decoded bool | 		var decoded bool | ||||||
| 		if string(secretBytes(o.CookieSecret)) != o.CookieSecret { | 		if string(secretBytes(o.CookieSecret)) != o.CookieSecret { | ||||||
| 			decoded = true | 			decoded = true | ||||||
| 		} | 		} | ||||||
| 		if valid_cookie_secret_size == false { | 		if validCookieSecretSize == false { | ||||||
| 			var suffix string | 			var suffix string | ||||||
| 			if decoded { | 			if decoded { | ||||||
| 				suffix = fmt.Sprintf(" note: cookie secret was base64 decoded from %q", o.CookieSecret) | 				suffix = fmt.Sprintf(" note: cookie secret was base64 decoded from %q", o.CookieSecret) | ||||||
|  | @ -294,12 +299,13 @@ func parseSignatureKey(o *Options, msgs []string) []string { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	algorithm, secretKey := components[0], components[1] | 	algorithm, secretKey := components[0], components[1] | ||||||
| 	if hash, err := hmacauth.DigestNameToCryptoHash(algorithm); err != nil { | 	var hash crypto.Hash | ||||||
|  | 	var err error | ||||||
|  | 	if hash, err = hmacauth.DigestNameToCryptoHash(algorithm); err != nil { | ||||||
| 		return append(msgs, "unsupported signature hash algorithm: "+ | 		return append(msgs, "unsupported signature hash algorithm: "+ | ||||||
| 			o.SignatureKey) | 			o.SignatureKey) | ||||||
| 	} else { |  | ||||||
| 		o.signatureData = &SignatureData{hash, secretKey} |  | ||||||
| 	} | 	} | ||||||
|  | 	o.signatureData = &SignatureData{hash, secretKey} | ||||||
| 	return msgs | 	return msgs | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -88,9 +88,9 @@ func TestProxyURLs(t *testing.T) { | ||||||
| 	o.Upstreams = append(o.Upstreams, "http://127.0.0.1:8081") | 	o.Upstreams = append(o.Upstreams, "http://127.0.0.1:8081") | ||||||
| 	assert.Equal(t, nil, o.Validate()) | 	assert.Equal(t, nil, o.Validate()) | ||||||
| 	expected := []*url.URL{ | 	expected := []*url.URL{ | ||||||
| 		&url.URL{Scheme: "http", Host: "127.0.0.1:8080", Path: "/"}, | 		{Scheme: "http", Host: "127.0.0.1:8080", Path: "/"}, | ||||||
| 		// note the '/' was added
 | 		// note the '/' was added
 | ||||||
| 		&url.URL{Scheme: "http", Host: "127.0.0.1:8081", Path: "/"}, | 		{Scheme: "http", Host: "127.0.0.1:8081", Path: "/"}, | ||||||
| 	} | 	} | ||||||
| 	assert.Equal(t, expected, o.proxyURLs) | 	assert.Equal(t, expected, o.proxyURLs) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -3,18 +3,21 @@ package providers | ||||||
| import ( | import ( | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/bitly/go-simplejson" |  | ||||||
| 	"github.com/bitly/oauth2_proxy/api" |  | ||||||
| 	"log" | 	"log" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/url" | 	"net/url" | ||||||
|  | 
 | ||||||
|  | 	"github.com/bitly/go-simplejson" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/api" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // AzureProvider represents an Azure based Identity Provider
 | ||||||
| type AzureProvider struct { | type AzureProvider struct { | ||||||
| 	*ProviderData | 	*ProviderData | ||||||
| 	Tenant string | 	Tenant string | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // NewAzureProvider initiates a new AzureProvider
 | ||||||
| func NewAzureProvider(p *ProviderData) *AzureProvider { | func NewAzureProvider(p *ProviderData) *AzureProvider { | ||||||
| 	p.ProviderName = "Azure" | 	p.ProviderName = "Azure" | ||||||
| 
 | 
 | ||||||
|  | @ -39,6 +42,7 @@ func NewAzureProvider(p *ProviderData) *AzureProvider { | ||||||
| 	return &AzureProvider{ProviderData: p} | 	return &AzureProvider{ProviderData: p} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // Configure defaults the AzureProvider configuration options
 | ||||||
| func (p *AzureProvider) Configure(tenant string) { | func (p *AzureProvider) Configure(tenant string) { | ||||||
| 	p.Tenant = tenant | 	p.Tenant = tenant | ||||||
| 	if tenant == "" { | 	if tenant == "" { | ||||||
|  | @ -60,9 +64,9 @@ func (p *AzureProvider) Configure(tenant string) { | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func getAzureHeader(access_token string) http.Header { | func getAzureHeader(accessToken string) http.Header { | ||||||
| 	header := make(http.Header) | 	header := make(http.Header) | ||||||
| 	header.Set("Authorization", fmt.Sprintf("Bearer %s", access_token)) | 	header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) | ||||||
| 	return header | 	return header | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -83,6 +87,7 @@ func getEmailFromJSON(json *simplejson.Json) (string, error) { | ||||||
| 	return email, err | 	return email, err | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // GetEmailAddress returns the Account email address
 | ||||||
| func (p *AzureProvider) GetEmailAddress(s *SessionState) (string, error) { | func (p *AzureProvider) GetEmailAddress(s *SessionState) (string, error) { | ||||||
| 	var email string | 	var email string | ||||||
| 	var err error | 	var err error | ||||||
|  |  | ||||||
|  | @ -110,8 +110,7 @@ func testAzureBackend(payload string) *httptest.Server { | ||||||
| 
 | 
 | ||||||
| 	return httptest.NewServer(http.HandlerFunc( | 	return httptest.NewServer(http.HandlerFunc( | ||||||
| 		func(w http.ResponseWriter, r *http.Request) { | 		func(w http.ResponseWriter, r *http.Request) { | ||||||
| 			url := r.URL | 			if r.URL.Path != path || r.URL.RawQuery != query { | ||||||
| 			if url.Path != path || url.RawQuery != query { |  | ||||||
| 				w.WriteHeader(404) | 				w.WriteHeader(404) | ||||||
| 			} else if r.Header.Get("Authorization") != "Bearer imaginary_access_token" { | 			} else if r.Header.Get("Authorization") != "Bearer imaginary_access_token" { | ||||||
| 				w.WriteHeader(403) | 				w.WriteHeader(403) | ||||||
|  |  | ||||||
|  | @ -6,13 +6,15 @@ import ( | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 
 | 
 | ||||||
| 	"github.com/bitly/oauth2_proxy/api" | 	"github.com/pusher/oauth2_proxy/api" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // FacebookProvider represents an Facebook based Identity Provider
 | ||||||
| type FacebookProvider struct { | type FacebookProvider struct { | ||||||
| 	*ProviderData | 	*ProviderData | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // NewFacebookProvider initiates a new FacebookProvider
 | ||||||
| func NewFacebookProvider(p *ProviderData) *FacebookProvider { | func NewFacebookProvider(p *ProviderData) *FacebookProvider { | ||||||
| 	p.ProviderName = "Facebook" | 	p.ProviderName = "Facebook" | ||||||
| 	if p.LoginURL.String() == "" { | 	if p.LoginURL.String() == "" { | ||||||
|  | @ -43,14 +45,15 @@ func NewFacebookProvider(p *ProviderData) *FacebookProvider { | ||||||
| 	return &FacebookProvider{ProviderData: p} | 	return &FacebookProvider{ProviderData: p} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func getFacebookHeader(access_token string) http.Header { | func getFacebookHeader(accessToken string) http.Header { | ||||||
| 	header := make(http.Header) | 	header := make(http.Header) | ||||||
| 	header.Set("Accept", "application/json") | 	header.Set("Accept", "application/json") | ||||||
| 	header.Set("x-li-format", "json") | 	header.Set("x-li-format", "json") | ||||||
| 	header.Set("Authorization", fmt.Sprintf("Bearer %s", access_token)) | 	header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) | ||||||
| 	return header | 	return header | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // GetEmailAddress returns the Account email address
 | ||||||
| func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) { | func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) { | ||||||
| 	if s.AccessToken == "" { | 	if s.AccessToken == "" { | ||||||
| 		return "", errors.New("missing access token") | 		return "", errors.New("missing access token") | ||||||
|  | @ -65,7 +68,7 @@ func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) { | ||||||
| 		Email string | 		Email string | ||||||
| 	} | 	} | ||||||
| 	var r result | 	var r result | ||||||
| 	err = api.RequestJson(req, &r) | 	err = api.RequestJSON(req, &r) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return "", err | ||||||
| 	} | 	} | ||||||
|  | @ -75,6 +78,7 @@ func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) { | ||||||
| 	return r.Email, nil | 	return r.Email, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // ValidateSessionState validates the AccessToken
 | ||||||
| func (p *FacebookProvider) ValidateSessionState(s *SessionState) bool { | func (p *FacebookProvider) ValidateSessionState(s *SessionState) bool { | ||||||
| 	return validateToken(p, s.AccessToken, getFacebookHeader(s.AccessToken)) | 	return validateToken(p, s.AccessToken, getFacebookHeader(s.AccessToken)) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -12,12 +12,14 @@ import ( | ||||||
| 	"strings" | 	"strings" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // GitHubProvider represents an GitHub based Identity Provider
 | ||||||
| type GitHubProvider struct { | type GitHubProvider struct { | ||||||
| 	*ProviderData | 	*ProviderData | ||||||
| 	Org  string | 	Org  string | ||||||
| 	Team string | 	Team string | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // NewGitHubProvider initiates a new GitHubProvider
 | ||||||
| func NewGitHubProvider(p *ProviderData) *GitHubProvider { | func NewGitHubProvider(p *ProviderData) *GitHubProvider { | ||||||
| 	p.ProviderName = "GitHub" | 	p.ProviderName = "GitHub" | ||||||
| 	if p.LoginURL == nil || p.LoginURL.String() == "" { | 	if p.LoginURL == nil || p.LoginURL.String() == "" { | ||||||
|  | @ -47,6 +49,8 @@ func NewGitHubProvider(p *ProviderData) *GitHubProvider { | ||||||
| 	} | 	} | ||||||
| 	return &GitHubProvider{ProviderData: p} | 	return &GitHubProvider{ProviderData: p} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // SetOrgTeam adds GitHub org reading parameters to the OAuth2 scope
 | ||||||
| func (p *GitHubProvider) SetOrgTeam(org, team string) { | func (p *GitHubProvider) SetOrgTeam(org, team string) { | ||||||
| 	p.Org = org | 	p.Org = org | ||||||
| 	p.Team = team | 	p.Team = team | ||||||
|  | @ -106,7 +110,7 @@ func (p *GitHubProvider) hasOrg(accessToken string) (bool, error) { | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		orgs = append(orgs, op...) | 		orgs = append(orgs, op...) | ||||||
| 		pn += 1 | 		pn++ | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	var presentOrgs []string | 	var presentOrgs []string | ||||||
|  | @ -186,7 +190,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) { | ||||||
| 		log.Printf("Missing Team:%q from Org:%q in teams: %v", p.Team, p.Org, presentTeams) | 		log.Printf("Missing Team:%q from Org:%q in teams: %v", p.Team, p.Org, presentTeams) | ||||||
| 	} else { | 	} else { | ||||||
| 		var allOrgs []string | 		var allOrgs []string | ||||||
| 		for org, _ := range presentOrgs { | 		for org := range presentOrgs { | ||||||
| 			allOrgs = append(allOrgs, org) | 			allOrgs = append(allOrgs, org) | ||||||
| 		} | 		} | ||||||
| 		log.Printf("Missing Organization:%q in %#v", p.Org, allOrgs) | 		log.Printf("Missing Organization:%q in %#v", p.Org, allOrgs) | ||||||
|  | @ -194,6 +198,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) { | ||||||
| 	return false, nil | 	return false, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // GetEmailAddress returns the Account email address
 | ||||||
| func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) { | func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) { | ||||||
| 
 | 
 | ||||||
| 	var emails []struct { | 	var emails []struct { | ||||||
|  | @ -251,6 +256,7 @@ func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) { | ||||||
| 	return "", nil | 	return "", nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // GetUserName returns the Account user name
 | ||||||
| func (p *GitHubProvider) GetUserName(s *SessionState) (string, error) { | func (p *GitHubProvider) GetUserName(s *SessionState) (string, error) { | ||||||
| 	var user struct { | 	var user struct { | ||||||
| 		Login string `json:"login"` | 		Login string `json:"login"` | ||||||
|  |  | ||||||
|  | @ -29,19 +29,18 @@ func testGitHubProvider(hostname string) *GitHubProvider { | ||||||
| 
 | 
 | ||||||
| func testGitHubBackend(payload []string) *httptest.Server { | func testGitHubBackend(payload []string) *httptest.Server { | ||||||
| 	pathToQueryMap := map[string][]string{ | 	pathToQueryMap := map[string][]string{ | ||||||
| 		"/user":        []string{""}, | 		"/user":        {""}, | ||||||
| 		"/user/emails": []string{""}, | 		"/user/emails": {""}, | ||||||
| 		"/user/orgs":   []string{"limit=200&page=1", "limit=200&page=2", "limit=200&page=3"}, | 		"/user/orgs":   {"limit=200&page=1", "limit=200&page=2", "limit=200&page=3"}, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return httptest.NewServer(http.HandlerFunc( | 	return httptest.NewServer(http.HandlerFunc( | ||||||
| 		func(w http.ResponseWriter, r *http.Request) { | 		func(w http.ResponseWriter, r *http.Request) { | ||||||
| 			url := r.URL | 			query, ok := pathToQueryMap[r.URL.Path] | ||||||
| 			query, ok := pathToQueryMap[url.Path] |  | ||||||
| 			validQuery := false | 			validQuery := false | ||||||
| 			index := 0 | 			index := 0 | ||||||
| 			for i, q := range query { | 			for i, q := range query { | ||||||
| 				if q == url.RawQuery { | 				if q == r.URL.RawQuery { | ||||||
| 					validQuery = true | 					validQuery = true | ||||||
| 					index = i | 					index = i | ||||||
| 				} | 				} | ||||||
|  |  | ||||||
|  | @ -5,13 +5,15 @@ import ( | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 
 | 
 | ||||||
| 	"github.com/bitly/oauth2_proxy/api" | 	"github.com/pusher/oauth2_proxy/api" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // GitLabProvider represents an GitLab based Identity Provider
 | ||||||
| type GitLabProvider struct { | type GitLabProvider struct { | ||||||
| 	*ProviderData | 	*ProviderData | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // NewGitLabProvider initiates a new GitLabProvider
 | ||||||
| func NewGitLabProvider(p *ProviderData) *GitLabProvider { | func NewGitLabProvider(p *ProviderData) *GitLabProvider { | ||||||
| 	p.ProviderName = "GitLab" | 	p.ProviderName = "GitLab" | ||||||
| 	if p.LoginURL == nil || p.LoginURL.String() == "" { | 	if p.LoginURL == nil || p.LoginURL.String() == "" { | ||||||
|  | @ -41,6 +43,7 @@ func NewGitLabProvider(p *ProviderData) *GitLabProvider { | ||||||
| 	return &GitLabProvider{ProviderData: p} | 	return &GitLabProvider{ProviderData: p} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // GetEmailAddress returns the Account email address
 | ||||||
| func (p *GitLabProvider) GetEmailAddress(s *SessionState) (string, error) { | func (p *GitLabProvider) GetEmailAddress(s *SessionState) (string, error) { | ||||||
| 
 | 
 | ||||||
| 	req, err := http.NewRequest("GET", | 	req, err := http.NewRequest("GET", | ||||||
|  |  | ||||||
|  | @ -33,8 +33,7 @@ func testGitLabBackend(payload string) *httptest.Server { | ||||||
| 
 | 
 | ||||||
| 	return httptest.NewServer(http.HandlerFunc( | 	return httptest.NewServer(http.HandlerFunc( | ||||||
| 		func(w http.ResponseWriter, r *http.Request) { | 		func(w http.ResponseWriter, r *http.Request) { | ||||||
| 			url := r.URL | 			if r.URL.Path != path || r.URL.RawQuery != query { | ||||||
| 			if url.Path != path || url.RawQuery != query { |  | ||||||
| 				w.WriteHeader(404) | 				w.WriteHeader(404) | ||||||
| 			} else { | 			} else { | ||||||
| 				w.WriteHeader(200) | 				w.WriteHeader(200) | ||||||
|  | @ -87,8 +86,8 @@ func TestGitLabProviderGetEmailAddress(t *testing.T) { | ||||||
| 	b := testGitLabBackend("{\"email\": \"michael.bland@gsa.gov\"}") | 	b := testGitLabBackend("{\"email\": \"michael.bland@gsa.gov\"}") | ||||||
| 	defer b.Close() | 	defer b.Close() | ||||||
| 
 | 
 | ||||||
| 	b_url, _ := url.Parse(b.URL) | 	bURL, _ := url.Parse(b.URL) | ||||||
| 	p := testGitLabProvider(b_url.Host) | 	p := testGitLabProvider(bURL.Host) | ||||||
| 
 | 
 | ||||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | 	session := &SessionState{AccessToken: "imaginary_access_token"} | ||||||
| 	email, err := p.GetEmailAddress(session) | 	email, err := p.GetEmailAddress(session) | ||||||
|  | @ -102,8 +101,8 @@ func TestGitLabProviderGetEmailAddressFailedRequest(t *testing.T) { | ||||||
| 	b := testGitLabBackend("unused payload") | 	b := testGitLabBackend("unused payload") | ||||||
| 	defer b.Close() | 	defer b.Close() | ||||||
| 
 | 
 | ||||||
| 	b_url, _ := url.Parse(b.URL) | 	bURL, _ := url.Parse(b.URL) | ||||||
| 	p := testGitLabProvider(b_url.Host) | 	p := testGitLabProvider(bURL.Host) | ||||||
| 
 | 
 | ||||||
| 	// We'll trigger a request failure by using an unexpected access
 | 	// We'll trigger a request failure by using an unexpected access
 | ||||||
| 	// token. Alternatively, we could allow the parsing of the payload as
 | 	// token. Alternatively, we could allow the parsing of the payload as
 | ||||||
|  | @ -118,8 +117,8 @@ func TestGitLabProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { | ||||||
| 	b := testGitLabBackend("{\"foo\": \"bar\"}") | 	b := testGitLabBackend("{\"foo\": \"bar\"}") | ||||||
| 	defer b.Close() | 	defer b.Close() | ||||||
| 
 | 
 | ||||||
| 	b_url, _ := url.Parse(b.URL) | 	bURL, _ := url.Parse(b.URL) | ||||||
| 	p := testGitLabProvider(b_url.Host) | 	p := testGitLabProvider(bURL.Host) | ||||||
| 
 | 
 | ||||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | 	session := &SessionState{AccessToken: "imaginary_access_token"} | ||||||
| 	email, err := p.GetEmailAddress(session) | 	email, err := p.GetEmailAddress(session) | ||||||
|  |  | ||||||
|  | @ -20,6 +20,7 @@ import ( | ||||||
| 	"google.golang.org/api/googleapi" | 	"google.golang.org/api/googleapi" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // GoogleProvider represents an Google based Identity Provider
 | ||||||
| type GoogleProvider struct { | type GoogleProvider struct { | ||||||
| 	*ProviderData | 	*ProviderData | ||||||
| 	RedeemRefreshURL *url.URL | 	RedeemRefreshURL *url.URL | ||||||
|  | @ -28,6 +29,7 @@ type GoogleProvider struct { | ||||||
| 	GroupValidator func(string) bool | 	GroupValidator func(string) bool | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // NewGoogleProvider initiates a new GoogleProvider
 | ||||||
| func NewGoogleProvider(p *ProviderData) *GoogleProvider { | func NewGoogleProvider(p *ProviderData) *GoogleProvider { | ||||||
| 	p.ProviderName = "Google" | 	p.ProviderName = "Google" | ||||||
| 	if p.LoginURL.String() == "" { | 	if p.LoginURL.String() == "" { | ||||||
|  | @ -62,7 +64,7 @@ func NewGoogleProvider(p *ProviderData) *GoogleProvider { | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func emailFromIdToken(idToken string) (string, error) { | func emailFromIDToken(idToken string) (string, error) { | ||||||
| 
 | 
 | ||||||
| 	// id_token is a base64 encode ID token payload
 | 	// id_token is a base64 encode ID token payload
 | ||||||
| 	// https://developers.google.com/accounts/docs/OAuth2Login#obtainuserinfo
 | 	// https://developers.google.com/accounts/docs/OAuth2Login#obtainuserinfo
 | ||||||
|  | @ -90,6 +92,7 @@ func emailFromIdToken(idToken string) (string, error) { | ||||||
| 	return email.Email, nil | 	return email.Email, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // Redeem exchanges the OAuth2 authentication token for an ID token
 | ||||||
| func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err error) { | func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err error) { | ||||||
| 	if code == "" { | 	if code == "" { | ||||||
| 		err = errors.New("missing code") | 		err = errors.New("missing code") | ||||||
|  | @ -129,14 +132,14 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err | ||||||
| 		AccessToken  string `json:"access_token"` | 		AccessToken  string `json:"access_token"` | ||||||
| 		RefreshToken string `json:"refresh_token"` | 		RefreshToken string `json:"refresh_token"` | ||||||
| 		ExpiresIn    int64  `json:"expires_in"` | 		ExpiresIn    int64  `json:"expires_in"` | ||||||
| 		IdToken      string `json:"id_token"` | 		IDToken      string `json:"id_token"` | ||||||
| 	} | 	} | ||||||
| 	err = json.Unmarshal(body, &jsonResponse) | 	err = json.Unmarshal(body, &jsonResponse) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	var email string | 	var email string | ||||||
| 	email, err = emailFromIdToken(jsonResponse.IdToken) | 	email, err = emailFromIDToken(jsonResponse.IDToken) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  | @ -249,6 +252,8 @@ func (p *GoogleProvider) ValidateGroup(email string) bool { | ||||||
| 	return p.GroupValidator(email) | 	return p.GroupValidator(email) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // RefreshSessionIfNeeded checks if the session has expired and uses the
 | ||||||
|  | // RefreshToken to fetch a new ID token if required
 | ||||||
| func (p *GoogleProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { | func (p *GoogleProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { | ||||||
| 	if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { | 	if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { | ||||||
| 		return false, nil | 		return false, nil | ||||||
|  |  | ||||||
|  | @ -81,7 +81,7 @@ type redeemResponse struct { | ||||||
| 	AccessToken  string `json:"access_token"` | 	AccessToken  string `json:"access_token"` | ||||||
| 	RefreshToken string `json:"refresh_token"` | 	RefreshToken string `json:"refresh_token"` | ||||||
| 	ExpiresIn    int64  `json:"expires_in"` | 	ExpiresIn    int64  `json:"expires_in"` | ||||||
| 	IdToken      string `json:"id_token"` | 	IDToken      string `json:"id_token"` | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestGoogleProviderGetEmailAddress(t *testing.T) { | func TestGoogleProviderGetEmailAddress(t *testing.T) { | ||||||
|  | @ -90,7 +90,7 @@ func TestGoogleProviderGetEmailAddress(t *testing.T) { | ||||||
| 		AccessToken:  "a1234", | 		AccessToken:  "a1234", | ||||||
| 		ExpiresIn:    10, | 		ExpiresIn:    10, | ||||||
| 		RefreshToken: "refresh12345", | 		RefreshToken: "refresh12345", | ||||||
| 		IdToken:      "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": "michael.bland@gsa.gov", "email_verified":true}`)), | 		IDToken:      "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": "michael.bland@gsa.gov", "email_verified":true}`)), | ||||||
| 	}) | 	}) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	var server *httptest.Server | 	var server *httptest.Server | ||||||
|  | @ -127,7 +127,7 @@ func TestGoogleProviderGetEmailAddressInvalidEncoding(t *testing.T) { | ||||||
| 	p := newGoogleProvider() | 	p := newGoogleProvider() | ||||||
| 	body, err := json.Marshal(redeemResponse{ | 	body, err := json.Marshal(redeemResponse{ | ||||||
| 		AccessToken: "a1234", | 		AccessToken: "a1234", | ||||||
| 		IdToken:     "ignored prefix." + `{"email": "michael.bland@gsa.gov"}`, | 		IDToken:     "ignored prefix." + `{"email": "michael.bland@gsa.gov"}`, | ||||||
| 	}) | 	}) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	var server *httptest.Server | 	var server *httptest.Server | ||||||
|  | @ -146,7 +146,7 @@ func TestGoogleProviderGetEmailAddressInvalidJson(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| 	body, err := json.Marshal(redeemResponse{ | 	body, err := json.Marshal(redeemResponse{ | ||||||
| 		AccessToken: "a1234", | 		AccessToken: "a1234", | ||||||
| 		IdToken:     "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": michael.bland@gsa.gov}`)), | 		IDToken:     "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": michael.bland@gsa.gov}`)), | ||||||
| 	}) | 	}) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	var server *httptest.Server | 	var server *httptest.Server | ||||||
|  | @ -165,7 +165,7 @@ func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) { | ||||||
| 	p := newGoogleProvider() | 	p := newGoogleProvider() | ||||||
| 	body, err := json.Marshal(redeemResponse{ | 	body, err := json.Marshal(redeemResponse{ | ||||||
| 		AccessToken: "a1234", | 		AccessToken: "a1234", | ||||||
| 		IdToken:     "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"not_email": "missing"}`)), | 		IDToken:     "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"not_email": "missing"}`)), | ||||||
| 	}) | 	}) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	var server *httptest.Server | 	var server *httptest.Server | ||||||
|  |  | ||||||
|  | @ -6,7 +6,7 @@ import ( | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 
 | 
 | ||||||
| 	"github.com/bitly/oauth2_proxy/api" | 	"github.com/pusher/oauth2_proxy/api" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // stripToken is a helper function to obfuscate "access_token"
 | // stripToken is a helper function to obfuscate "access_token"
 | ||||||
|  | @ -46,13 +46,13 @@ func stripParam(param, endpoint string) string { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // validateToken returns true if token is valid
 | // validateToken returns true if token is valid
 | ||||||
| func validateToken(p Provider, access_token string, header http.Header) bool { | func validateToken(p Provider, accessToken string, header http.Header) bool { | ||||||
| 	if access_token == "" || p.Data().ValidateURL == nil { | 	if accessToken == "" || p.Data().ValidateURL == nil { | ||||||
| 		return false | 		return false | ||||||
| 	} | 	} | ||||||
| 	endpoint := p.Data().ValidateURL.String() | 	endpoint := p.Data().ValidateURL.String() | ||||||
| 	if len(header) == 0 { | 	if len(header) == 0 { | ||||||
| 		params := url.Values{"access_token": {access_token}} | 		params := url.Values{"access_token": {accessToken}} | ||||||
| 		endpoint = endpoint + "?" + params.Encode() | 		endpoint = endpoint + "?" + params.Encode() | ||||||
| 	} | 	} | ||||||
| 	resp, err := api.RequestUnparsedResponse(endpoint, header) | 	resp, err := api.RequestUnparsedResponse(endpoint, header) | ||||||
|  | @ -72,8 +72,3 @@ func validateToken(p Provider, access_token string, header http.Header) bool { | ||||||
| 	log.Printf("token validation request failed: status %d - %s", resp.StatusCode, body) | 	log.Printf("token validation request failed: status %d - %s", resp.StatusCode, body) | ||||||
| 	return false | 	return false | ||||||
| } | } | ||||||
| 
 |  | ||||||
| func updateURL(url *url.URL, hostname string) { |  | ||||||
| 	url.Scheme = "http" |  | ||||||
| 	url.Host = hostname |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  | @ -10,6 +10,11 @@ import ( | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | func updateURL(url *url.URL, hostname string) { | ||||||
|  | 	url.Scheme = "http" | ||||||
|  | 	url.Host = hostname | ||||||
|  | } | ||||||
|  | 
 | ||||||
| type ValidateSessionStateTestProvider struct { | type ValidateSessionStateTestProvider struct { | ||||||
| 	*ProviderData | 	*ProviderData | ||||||
| } | } | ||||||
|  | @ -25,28 +30,28 @@ func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *SessionState | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type ValidateSessionStateTest struct { | type ValidateSessionStateTest struct { | ||||||
| 	backend       *httptest.Server | 	backend      *httptest.Server | ||||||
| 	response_code int | 	responseCode int | ||||||
| 	provider      *ValidateSessionStateTestProvider | 	provider     *ValidateSessionStateTestProvider | ||||||
| 	header        http.Header | 	header       http.Header | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func NewValidateSessionStateTest() *ValidateSessionStateTest { | func NewValidateSessionStateTest() *ValidateSessionStateTest { | ||||||
| 	var vt_test ValidateSessionStateTest | 	var vtTest ValidateSessionStateTest | ||||||
| 
 | 
 | ||||||
| 	vt_test.backend = httptest.NewServer( | 	vtTest.backend = httptest.NewServer( | ||||||
| 		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | 		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||||
| 			if r.URL.Path != "/oauth/tokeninfo" { | 			if r.URL.Path != "/oauth/tokeninfo" { | ||||||
| 				w.WriteHeader(500) | 				w.WriteHeader(500) | ||||||
| 				w.Write([]byte("unknown URL")) | 				w.Write([]byte("unknown URL")) | ||||||
| 			} | 			} | ||||||
| 			token_param := r.FormValue("access_token") | 			tokenParam := r.FormValue("access_token") | ||||||
| 			if token_param == "" { | 			if tokenParam == "" { | ||||||
| 				missing := false | 				missing := false | ||||||
| 				received_headers := r.Header | 				receivedHeaders := r.Header | ||||||
| 				for k, _ := range vt_test.header { | 				for k := range vtTest.header { | ||||||
| 					received := received_headers.Get(k) | 					received := receivedHeaders.Get(k) | ||||||
| 					expected := vt_test.header.Get(k) | 					expected := vtTest.header.Get(k) | ||||||
| 					if received == "" || received != expected { | 					if received == "" || received != expected { | ||||||
| 						missing = true | 						missing = true | ||||||
| 					} | 					} | ||||||
|  | @ -56,68 +61,68 @@ func NewValidateSessionStateTest() *ValidateSessionStateTest { | ||||||
| 					w.Write([]byte("no token param and missing or incorrect headers")) | 					w.Write([]byte("no token param and missing or incorrect headers")) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 			w.WriteHeader(vt_test.response_code) | 			w.WriteHeader(vtTest.responseCode) | ||||||
| 			w.Write([]byte("only code matters; contents disregarded")) | 			w.Write([]byte("only code matters; contents disregarded")) | ||||||
| 
 | 
 | ||||||
| 		})) | 		})) | ||||||
| 	backend_url, _ := url.Parse(vt_test.backend.URL) | 	backendURL, _ := url.Parse(vtTest.backend.URL) | ||||||
| 	vt_test.provider = &ValidateSessionStateTestProvider{ | 	vtTest.provider = &ValidateSessionStateTestProvider{ | ||||||
| 		ProviderData: &ProviderData{ | 		ProviderData: &ProviderData{ | ||||||
| 			ValidateURL: &url.URL{ | 			ValidateURL: &url.URL{ | ||||||
| 				Scheme: "http", | 				Scheme: "http", | ||||||
| 				Host:   backend_url.Host, | 				Host:   backendURL.Host, | ||||||
| 				Path:   "/oauth/tokeninfo", | 				Path:   "/oauth/tokeninfo", | ||||||
| 			}, | 			}, | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| 	vt_test.response_code = 200 | 	vtTest.responseCode = 200 | ||||||
| 	return &vt_test | 	return &vtTest | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (vt_test *ValidateSessionStateTest) Close() { | func (vtTest *ValidateSessionStateTest) Close() { | ||||||
| 	vt_test.backend.Close() | 	vtTest.backend.Close() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestValidateSessionStateValidToken(t *testing.T) { | func TestValidateSessionStateValidToken(t *testing.T) { | ||||||
| 	vt_test := NewValidateSessionStateTest() | 	vtTest := NewValidateSessionStateTest() | ||||||
| 	defer vt_test.Close() | 	defer vtTest.Close() | ||||||
| 	assert.Equal(t, true, validateToken(vt_test.provider, "foobar", nil)) | 	assert.Equal(t, true, validateToken(vtTest.provider, "foobar", nil)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestValidateSessionStateValidTokenWithHeaders(t *testing.T) { | func TestValidateSessionStateValidTokenWithHeaders(t *testing.T) { | ||||||
| 	vt_test := NewValidateSessionStateTest() | 	vtTest := NewValidateSessionStateTest() | ||||||
| 	defer vt_test.Close() | 	defer vtTest.Close() | ||||||
| 	vt_test.header = make(http.Header) | 	vtTest.header = make(http.Header) | ||||||
| 	vt_test.header.Set("Authorization", "Bearer foobar") | 	vtTest.header.Set("Authorization", "Bearer foobar") | ||||||
| 	assert.Equal(t, true, | 	assert.Equal(t, true, | ||||||
| 		validateToken(vt_test.provider, "foobar", vt_test.header)) | 		validateToken(vtTest.provider, "foobar", vtTest.header)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestValidateSessionStateEmptyToken(t *testing.T) { | func TestValidateSessionStateEmptyToken(t *testing.T) { | ||||||
| 	vt_test := NewValidateSessionStateTest() | 	vtTest := NewValidateSessionStateTest() | ||||||
| 	defer vt_test.Close() | 	defer vtTest.Close() | ||||||
| 	assert.Equal(t, false, validateToken(vt_test.provider, "", nil)) | 	assert.Equal(t, false, validateToken(vtTest.provider, "", nil)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestValidateSessionStateEmptyValidateURL(t *testing.T) { | func TestValidateSessionStateEmptyValidateURL(t *testing.T) { | ||||||
| 	vt_test := NewValidateSessionStateTest() | 	vtTest := NewValidateSessionStateTest() | ||||||
| 	defer vt_test.Close() | 	defer vtTest.Close() | ||||||
| 	vt_test.provider.Data().ValidateURL = nil | 	vtTest.provider.Data().ValidateURL = nil | ||||||
| 	assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil)) | 	assert.Equal(t, false, validateToken(vtTest.provider, "foobar", nil)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestValidateSessionStateRequestNetworkFailure(t *testing.T) { | func TestValidateSessionStateRequestNetworkFailure(t *testing.T) { | ||||||
| 	vt_test := NewValidateSessionStateTest() | 	vtTest := NewValidateSessionStateTest() | ||||||
| 	// Close immediately to simulate a network failure
 | 	// Close immediately to simulate a network failure
 | ||||||
| 	vt_test.Close() | 	vtTest.Close() | ||||||
| 	assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil)) | 	assert.Equal(t, false, validateToken(vtTest.provider, "foobar", nil)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestValidateSessionStateExpiredToken(t *testing.T) { | func TestValidateSessionStateExpiredToken(t *testing.T) { | ||||||
| 	vt_test := NewValidateSessionStateTest() | 	vtTest := NewValidateSessionStateTest() | ||||||
| 	defer vt_test.Close() | 	defer vtTest.Close() | ||||||
| 	vt_test.response_code = 401 | 	vtTest.responseCode = 401 | ||||||
| 	assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil)) | 	assert.Equal(t, false, validateToken(vtTest.provider, "foobar", nil)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestStripTokenNotPresent(t *testing.T) { | func TestStripTokenNotPresent(t *testing.T) { | ||||||
|  |  | ||||||
|  | @ -6,13 +6,15 @@ import ( | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 
 | 
 | ||||||
| 	"github.com/bitly/oauth2_proxy/api" | 	"github.com/pusher/oauth2_proxy/api" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // LinkedInProvider represents an LinkedIn based Identity Provider
 | ||||||
| type LinkedInProvider struct { | type LinkedInProvider struct { | ||||||
| 	*ProviderData | 	*ProviderData | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // NewLinkedInProvider initiates a new LinkedInProvider
 | ||||||
| func NewLinkedInProvider(p *ProviderData) *LinkedInProvider { | func NewLinkedInProvider(p *ProviderData) *LinkedInProvider { | ||||||
| 	p.ProviderName = "LinkedIn" | 	p.ProviderName = "LinkedIn" | ||||||
| 	if p.LoginURL.String() == "" { | 	if p.LoginURL.String() == "" { | ||||||
|  | @ -39,14 +41,15 @@ func NewLinkedInProvider(p *ProviderData) *LinkedInProvider { | ||||||
| 	return &LinkedInProvider{ProviderData: p} | 	return &LinkedInProvider{ProviderData: p} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func getLinkedInHeader(access_token string) http.Header { | func getLinkedInHeader(accessToken string) http.Header { | ||||||
| 	header := make(http.Header) | 	header := make(http.Header) | ||||||
| 	header.Set("Accept", "application/json") | 	header.Set("Accept", "application/json") | ||||||
| 	header.Set("x-li-format", "json") | 	header.Set("x-li-format", "json") | ||||||
| 	header.Set("Authorization", fmt.Sprintf("Bearer %s", access_token)) | 	header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) | ||||||
| 	return header | 	return header | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // GetEmailAddress returns the Account email address
 | ||||||
| func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) { | func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) { | ||||||
| 	if s.AccessToken == "" { | 	if s.AccessToken == "" { | ||||||
| 		return "", errors.New("missing access token") | 		return "", errors.New("missing access token") | ||||||
|  | @ -69,6 +72,7 @@ func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) { | ||||||
| 	return email, nil | 	return email, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // ValidateSessionState validates the AccessToken
 | ||||||
| func (p *LinkedInProvider) ValidateSessionState(s *SessionState) bool { | func (p *LinkedInProvider) ValidateSessionState(s *SessionState) bool { | ||||||
| 	return validateToken(p, s.AccessToken, getLinkedInHeader(s.AccessToken)) | 	return validateToken(p, s.AccessToken, getLinkedInHeader(s.AccessToken)) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -31,8 +31,7 @@ func testLinkedInBackend(payload string) *httptest.Server { | ||||||
| 
 | 
 | ||||||
| 	return httptest.NewServer(http.HandlerFunc( | 	return httptest.NewServer(http.HandlerFunc( | ||||||
| 		func(w http.ResponseWriter, r *http.Request) { | 		func(w http.ResponseWriter, r *http.Request) { | ||||||
| 			url := r.URL | 			if r.URL.Path != path { | ||||||
| 			if url.Path != path { |  | ||||||
| 				w.WriteHeader(404) | 				w.WriteHeader(404) | ||||||
| 			} else if r.Header.Get("Authorization") != "Bearer imaginary_access_token" { | 			} else if r.Header.Get("Authorization") != "Bearer imaginary_access_token" { | ||||||
| 				w.WriteHeader(403) | 				w.WriteHeader(403) | ||||||
|  | @ -95,8 +94,8 @@ func TestLinkedInProviderGetEmailAddress(t *testing.T) { | ||||||
| 	b := testLinkedInBackend(`"user@linkedin.com"`) | 	b := testLinkedInBackend(`"user@linkedin.com"`) | ||||||
| 	defer b.Close() | 	defer b.Close() | ||||||
| 
 | 
 | ||||||
| 	b_url, _ := url.Parse(b.URL) | 	bURL, _ := url.Parse(b.URL) | ||||||
| 	p := testLinkedInProvider(b_url.Host) | 	p := testLinkedInProvider(bURL.Host) | ||||||
| 
 | 
 | ||||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | 	session := &SessionState{AccessToken: "imaginary_access_token"} | ||||||
| 	email, err := p.GetEmailAddress(session) | 	email, err := p.GetEmailAddress(session) | ||||||
|  | @ -108,8 +107,8 @@ func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) { | ||||||
| 	b := testLinkedInBackend("unused payload") | 	b := testLinkedInBackend("unused payload") | ||||||
| 	defer b.Close() | 	defer b.Close() | ||||||
| 
 | 
 | ||||||
| 	b_url, _ := url.Parse(b.URL) | 	bURL, _ := url.Parse(b.URL) | ||||||
| 	p := testLinkedInProvider(b_url.Host) | 	p := testLinkedInProvider(bURL.Host) | ||||||
| 
 | 
 | ||||||
| 	// We'll trigger a request failure by using an unexpected access
 | 	// We'll trigger a request failure by using an unexpected access
 | ||||||
| 	// token. Alternatively, we could allow the parsing of the payload as
 | 	// token. Alternatively, we could allow the parsing of the payload as
 | ||||||
|  | @ -124,8 +123,8 @@ func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { | ||||||
| 	b := testLinkedInBackend("{\"foo\": \"bar\"}") | 	b := testLinkedInBackend("{\"foo\": \"bar\"}") | ||||||
| 	defer b.Close() | 	defer b.Close() | ||||||
| 
 | 
 | ||||||
| 	b_url, _ := url.Parse(b.URL) | 	bURL, _ := url.Parse(b.URL) | ||||||
| 	p := testLinkedInProvider(b_url.Host) | 	p := testLinkedInProvider(bURL.Host) | ||||||
| 
 | 
 | ||||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | 	session := &SessionState{AccessToken: "imaginary_access_token"} | ||||||
| 	email, err := p.GetEmailAddress(session) | 	email, err := p.GetEmailAddress(session) | ||||||
|  |  | ||||||
|  | @ -10,17 +10,20 @@ import ( | ||||||
| 	oidc "github.com/coreos/go-oidc" | 	oidc "github.com/coreos/go-oidc" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // OIDCProvider represents an OIDC based Identity Provider
 | ||||||
| type OIDCProvider struct { | type OIDCProvider struct { | ||||||
| 	*ProviderData | 	*ProviderData | ||||||
| 
 | 
 | ||||||
| 	Verifier *oidc.IDTokenVerifier | 	Verifier *oidc.IDTokenVerifier | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // NewOIDCProvider initiates a new OIDCProvider
 | ||||||
| func NewOIDCProvider(p *ProviderData) *OIDCProvider { | func NewOIDCProvider(p *ProviderData) *OIDCProvider { | ||||||
| 	p.ProviderName = "OpenID Connect" | 	p.ProviderName = "OpenID Connect" | ||||||
| 	return &OIDCProvider{ProviderData: p} | 	return &OIDCProvider{ProviderData: p} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // Redeem exchanges the OAuth2 authentication token for an ID token
 | ||||||
| func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err error) { | func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err error) { | ||||||
| 	ctx := context.Background() | 	ctx := context.Background() | ||||||
| 	c := oauth2.Config{ | 	c := oauth2.Config{ | ||||||
|  | @ -73,6 +76,11 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err er | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // RefreshSessionIfNeeded checks if the session has expired and uses the
 | ||||||
|  | // RefreshToken to fetch a new ID token if required
 | ||||||
|  | //
 | ||||||
|  | // WARNGING: This implementation is broken and does not check with the upstream
 | ||||||
|  | // OIDC provider before refreshing the session
 | ||||||
| func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { | func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { | ||||||
| 	if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { | 	if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { | ||||||
| 		return false, nil | 		return false, nil | ||||||
|  |  | ||||||
|  | @ -4,6 +4,8 @@ import ( | ||||||
| 	"net/url" | 	"net/url" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // ProviderData contains information required to configure all implementations
 | ||||||
|  | // of OAuth2 providers
 | ||||||
| type ProviderData struct { | type ProviderData struct { | ||||||
| 	ProviderName      string | 	ProviderName      string | ||||||
| 	ClientID          string | 	ClientID          string | ||||||
|  | @ -17,4 +19,5 @@ type ProviderData struct { | ||||||
| 	ApprovalPrompt    string | 	ApprovalPrompt    string | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // Data returns the ProviderData
 | ||||||
| func (p *ProviderData) Data() *ProviderData { return p } | func (p *ProviderData) Data() *ProviderData { return p } | ||||||
|  |  | ||||||
|  | @ -9,9 +9,10 @@ import ( | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 
 | 
 | ||||||
| 	"github.com/bitly/oauth2_proxy/cookie" | 	"github.com/pusher/oauth2_proxy/cookie" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // Redeem provides a default implementation of the OAuth2 token redemption process
 | ||||||
| func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err error) { | func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err error) { | ||||||
| 	if code == "" { | 	if code == "" { | ||||||
| 		err = errors.New("missing code") | 		err = errors.New("missing code") | ||||||
|  | @ -102,6 +103,7 @@ func (p *ProviderData) SessionFromCookie(v string, c *cookie.Cipher) (s *Session | ||||||
| 	return DecodeSessionState(v, c) | 	return DecodeSessionState(v, c) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // GetEmailAddress returns the Account email address
 | ||||||
| func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) { | func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) { | ||||||
| 	return "", errors.New("not implemented") | 	return "", errors.New("not implemented") | ||||||
| } | } | ||||||
|  | @ -117,11 +119,13 @@ func (p *ProviderData) ValidateGroup(email string) bool { | ||||||
| 	return true | 	return true | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // ValidateSessionState validates the AccessToken
 | ||||||
| func (p *ProviderData) ValidateSessionState(s *SessionState) bool { | func (p *ProviderData) ValidateSessionState(s *SessionState) bool { | ||||||
| 	return validateToken(p, s.AccessToken, nil) | 	return validateToken(p, s.AccessToken, nil) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // RefreshSessionIfNeeded
 | // RefreshSessionIfNeeded should refresh the user's session if required and
 | ||||||
|  | // do nothing if a refresh is not required
 | ||||||
| func (p *ProviderData) RefreshSessionIfNeeded(s *SessionState) (bool, error) { | func (p *ProviderData) RefreshSessionIfNeeded(s *SessionState) (bool, error) { | ||||||
| 	return false, nil | 	return false, nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -1,9 +1,10 @@ | ||||||
| package providers | package providers | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"github.com/bitly/oauth2_proxy/cookie" | 	"github.com/pusher/oauth2_proxy/cookie" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // Provider represents an upstream identity provider implementation
 | ||||||
| type Provider interface { | type Provider interface { | ||||||
| 	Data() *ProviderData | 	Data() *ProviderData | ||||||
| 	GetEmailAddress(*SessionState) (string, error) | 	GetEmailAddress(*SessionState) (string, error) | ||||||
|  | @ -17,6 +18,7 @@ type Provider interface { | ||||||
| 	CookieForSession(*SessionState, *cookie.Cipher) (string, error) | 	CookieForSession(*SessionState, *cookie.Cipher) (string, error) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // New provides a new Provider based on the configured provider string
 | ||||||
| func New(provider string, p *ProviderData) Provider { | func New(provider string, p *ProviderData) Provider { | ||||||
| 	switch provider { | 	switch provider { | ||||||
| 	case "linkedin": | 	case "linkedin": | ||||||
|  |  | ||||||
|  | @ -6,9 +6,10 @@ import ( | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/bitly/oauth2_proxy/cookie" | 	"github.com/pusher/oauth2_proxy/cookie" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // SessionState is used to store information about the currently authenticated user session
 | ||||||
| type SessionState struct { | type SessionState struct { | ||||||
| 	AccessToken  string | 	AccessToken  string | ||||||
| 	ExpiresOn    time.Time | 	ExpiresOn    time.Time | ||||||
|  | @ -17,6 +18,7 @@ type SessionState struct { | ||||||
| 	User         string | 	User         string | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // IsExpired checks whether the session has expired
 | ||||||
| func (s *SessionState) IsExpired() bool { | func (s *SessionState) IsExpired() bool { | ||||||
| 	if !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) { | 	if !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) { | ||||||
| 		return true | 		return true | ||||||
|  | @ -24,6 +26,7 @@ func (s *SessionState) IsExpired() bool { | ||||||
| 	return false | 	return false | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // String constructs a summary of the session state
 | ||||||
| func (s *SessionState) String() string { | func (s *SessionState) String() string { | ||||||
| 	o := fmt.Sprintf("Session{%s", s.accountInfo()) | 	o := fmt.Sprintf("Session{%s", s.accountInfo()) | ||||||
| 	if s.AccessToken != "" { | 	if s.AccessToken != "" { | ||||||
|  | @ -38,6 +41,7 @@ func (s *SessionState) String() string { | ||||||
| 	return o + "}" | 	return o + "}" | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // EncodeSessionState returns string representation of the current session
 | ||||||
| func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) { | func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) { | ||||||
| 	if c == nil || s.AccessToken == "" { | 	if c == nil || s.AccessToken == "" { | ||||||
| 		return s.accountInfo(), nil | 		return s.accountInfo(), nil | ||||||
|  | @ -49,6 +53,7 @@ func (s *SessionState) accountInfo() string { | ||||||
| 	return fmt.Sprintf("email:%s user:%s", s.Email, s.User) | 	return fmt.Sprintf("email:%s user:%s", s.Email, s.User) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // EncryptedString encrypts the session state into a cookie string
 | ||||||
| func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) { | func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) { | ||||||
| 	var err error | 	var err error | ||||||
| 	if c == nil { | 	if c == nil { | ||||||
|  | @ -84,6 +89,7 @@ func decodeSessionStatePlain(v string) (s *SessionState, err error) { | ||||||
| 	return &SessionState{User: user, Email: email}, nil | 	return &SessionState{User: user, Email: email}, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // DecodeSessionState decodes the session cookie string into a SessionState
 | ||||||
| func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) { | func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) { | ||||||
| 	if c == nil { | 	if c == nil { | ||||||
| 		return decodeSessionStatePlain(v) | 		return decodeSessionStatePlain(v) | ||||||
|  |  | ||||||
|  | @ -6,7 +6,7 @@ import ( | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/bitly/oauth2_proxy/cookie" | 	"github.com/pusher/oauth2_proxy/cookie" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -4,13 +4,16 @@ import ( | ||||||
| 	"strings" | 	"strings" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // StringArray is a type alias for a slice of strings
 | ||||||
| type StringArray []string | type StringArray []string | ||||||
| 
 | 
 | ||||||
|  | // Set appends a string to the StringArray
 | ||||||
| func (a *StringArray) Set(s string) error { | func (a *StringArray) Set(s string) error { | ||||||
| 	*a = append(*a, s) | 	*a = append(*a, s) | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // String joins elements of the StringArray into a single comma separated string
 | ||||||
| func (a *StringArray) String() string { | func (a *StringArray) String() string { | ||||||
| 	return strings.Join(*a, ",") | 	return strings.Join(*a, ",") | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -142,7 +142,7 @@ func getTemplates() *template.Template { | ||||||
| 	<footer> | 	<footer> | ||||||
| 	{{ if eq .Footer "-" }} | 	{{ if eq .Footer "-" }} | ||||||
| 	{{ else if eq .Footer ""}} | 	{{ else if eq .Footer ""}} | ||||||
| 	Secured with <a href="https://github.com/bitly/oauth2_proxy#oauth2_proxy">OAuth2 Proxy</a> version {{.Version}} | 	Secured with <a href="https://github.com/pusher/oauth2_proxy#oauth2_proxy">OAuth2 Proxy</a> version {{.Version}} | ||||||
| 	{{ else }} | 	{{ else }} | ||||||
| 	{{.Footer}} | 	{{.Footer}} | ||||||
| 	{{ end }} | 	{{ end }} | ||||||
|  |  | ||||||
							
								
								
									
										14
									
								
								test.sh
								
								
								
								
							
							
						
						
									
										14
									
								
								test.sh
								
								
								
								
							|  | @ -1,14 +0,0 @@ | ||||||
| #!/bin/bash |  | ||||||
| EXIT_CODE=0 |  | ||||||
| echo "gofmt" |  | ||||||
| diff -u <(echo -n) <(gofmt -d $(find . -type f -name '*.go' -not -path "./vendor/*")) || EXIT_CODE=1 |  | ||||||
| for pkg in $(go list ./... | grep -v '/vendor/' ); do |  | ||||||
|     echo "testing $pkg" |  | ||||||
|     echo "go vet $pkg" |  | ||||||
|     go vet "$pkg" || EXIT_CODE=1 |  | ||||||
|     echo "go test -v $pkg" |  | ||||||
|     go test -v -timeout 90s "$pkg" || EXIT_CODE=1 |  | ||||||
|     echo "go test -v -race $pkg" |  | ||||||
|     GOMAXPROCS=4 go test -v -timeout 90s0s -race "$pkg" || EXIT_CODE=1 |  | ||||||
| done |  | ||||||
| exit $EXIT_CODE |  | ||||||
							
								
								
									
										16
									
								
								validator.go
								
								
								
								
							
							
						
						
									
										16
									
								
								validator.go
								
								
								
								
							|  | @ -10,11 +10,13 @@ import ( | ||||||
| 	"unsafe" | 	"unsafe" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // UserMap holds information from the authenticated emails file
 | ||||||
| type UserMap struct { | type UserMap struct { | ||||||
| 	usersFile string | 	usersFile string | ||||||
| 	m         unsafe.Pointer | 	m         unsafe.Pointer | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // NewUserMap parses the authenticated emails file into a new UserMap
 | ||||||
| func NewUserMap(usersFile string, done <-chan bool, onUpdate func()) *UserMap { | func NewUserMap(usersFile string, done <-chan bool, onUpdate func()) *UserMap { | ||||||
| 	um := &UserMap{usersFile: usersFile} | 	um := &UserMap{usersFile: usersFile} | ||||||
| 	m := make(map[string]bool) | 	m := make(map[string]bool) | ||||||
|  | @ -30,23 +32,26 @@ func NewUserMap(usersFile string, done <-chan bool, onUpdate func()) *UserMap { | ||||||
| 	return um | 	return um | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // IsValid checks if an email is allowed
 | ||||||
| func (um *UserMap) IsValid(email string) (result bool) { | func (um *UserMap) IsValid(email string) (result bool) { | ||||||
| 	m := *(*map[string]bool)(atomic.LoadPointer(&um.m)) | 	m := *(*map[string]bool)(atomic.LoadPointer(&um.m)) | ||||||
| 	_, result = m[email] | 	_, result = m[email] | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // LoadAuthenticatedEmailsFile loads the authenticated emails file from disk
 | ||||||
|  | // and parses the contents as CSV
 | ||||||
| func (um *UserMap) LoadAuthenticatedEmailsFile() { | func (um *UserMap) LoadAuthenticatedEmailsFile() { | ||||||
| 	r, err := os.Open(um.usersFile) | 	r, err := os.Open(um.usersFile) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Fatalf("failed opening authenticated-emails-file=%q, %s", um.usersFile, err) | 		log.Fatalf("failed opening authenticated-emails-file=%q, %s", um.usersFile, err) | ||||||
| 	} | 	} | ||||||
| 	defer r.Close() | 	defer r.Close() | ||||||
| 	csv_reader := csv.NewReader(r) | 	csvReader := csv.NewReader(r) | ||||||
| 	csv_reader.Comma = ',' | 	csvReader.Comma = ',' | ||||||
| 	csv_reader.Comment = '#' | 	csvReader.Comment = '#' | ||||||
| 	csv_reader.TrimLeadingSpace = true | 	csvReader.TrimLeadingSpace = true | ||||||
| 	records, err := csv_reader.ReadAll() | 	records, err := csvReader.ReadAll() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Printf("error reading authenticated-emails-file=%q, %s", um.usersFile, err) | 		log.Printf("error reading authenticated-emails-file=%q, %s", um.usersFile, err) | ||||||
| 		return | 		return | ||||||
|  | @ -91,6 +96,7 @@ func newValidatorImpl(domains []string, usersFile string, | ||||||
| 	return validator | 	return validator | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // NewValidator constructs a function to validate email addresses
 | ||||||
| func NewValidator(domains []string, usersFile string) func(string) bool { | func NewValidator(domains []string, usersFile string) func(string) bool { | ||||||
| 	return newValidatorImpl(domains, usersFile, nil, func() {}) | 	return newValidatorImpl(domains, usersFile, nil, func() {}) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -8,15 +8,15 @@ import ( | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type ValidatorTest struct { | type ValidatorTest struct { | ||||||
| 	auth_email_file *os.File | 	authEmailFile *os.File | ||||||
| 	done            chan bool | 	done          chan bool | ||||||
| 	update_seen     bool | 	updateSeen    bool | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func NewValidatorTest(t *testing.T) *ValidatorTest { | func NewValidatorTest(t *testing.T) *ValidatorTest { | ||||||
| 	vt := &ValidatorTest{} | 	vt := &ValidatorTest{} | ||||||
| 	var err error | 	var err error | ||||||
| 	vt.auth_email_file, err = ioutil.TempFile("", "test_auth_emails_") | 	vt.authEmailFile, err = ioutil.TempFile("", "test_auth_emails_") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal("failed to create temp file: " + err.Error()) | 		t.Fatal("failed to create temp file: " + err.Error()) | ||||||
| 	} | 	} | ||||||
|  | @ -26,27 +26,27 @@ func NewValidatorTest(t *testing.T) *ValidatorTest { | ||||||
| 
 | 
 | ||||||
| func (vt *ValidatorTest) TearDown() { | func (vt *ValidatorTest) TearDown() { | ||||||
| 	vt.done <- true | 	vt.done <- true | ||||||
| 	os.Remove(vt.auth_email_file.Name()) | 	os.Remove(vt.authEmailFile.Name()) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (vt *ValidatorTest) NewValidator(domains []string, | func (vt *ValidatorTest) NewValidator(domains []string, | ||||||
| 	updated chan<- bool) func(string) bool { | 	updated chan<- bool) func(string) bool { | ||||||
| 	return newValidatorImpl(domains, vt.auth_email_file.Name(), | 	return newValidatorImpl(domains, vt.authEmailFile.Name(), | ||||||
| 		vt.done, func() { | 		vt.done, func() { | ||||||
| 			if vt.update_seen == false { | 			if vt.updateSeen == false { | ||||||
| 				updated <- true | 				updated <- true | ||||||
| 				vt.update_seen = true | 				vt.updateSeen = true | ||||||
| 			} | 			} | ||||||
| 		}) | 		}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // This will close vt.auth_email_file.
 | // This will close vt.authEmailFile.
 | ||||||
| func (vt *ValidatorTest) WriteEmails(t *testing.T, emails []string) { | func (vt *ValidatorTest) WriteEmails(t *testing.T, emails []string) { | ||||||
| 	defer vt.auth_email_file.Close() | 	defer vt.authEmailFile.Close() | ||||||
| 	vt.auth_email_file.WriteString(strings.Join(emails, "\n")) | 	vt.authEmailFile.WriteString(strings.Join(emails, "\n")) | ||||||
| 	if err := vt.auth_email_file.Close(); err != nil { | 	if err := vt.authEmailFile.Close(); err != nil { | ||||||
| 		t.Fatal("failed to close temp file " + | 		t.Fatal("failed to close temp file " + | ||||||
| 			vt.auth_email_file.Name() + ": " + err.Error()) | 			vt.authEmailFile.Name() + ": " + err.Error()) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -12,18 +12,18 @@ import ( | ||||||
| 
 | 
 | ||||||
| func (vt *ValidatorTest) UpdateEmailFileViaCopyingOver( | func (vt *ValidatorTest) UpdateEmailFileViaCopyingOver( | ||||||
| 	t *testing.T, emails []string) { | 	t *testing.T, emails []string) { | ||||||
| 	orig_file := vt.auth_email_file | 	origFile := vt.authEmailFile | ||||||
| 	var err error | 	var err error | ||||||
| 	vt.auth_email_file, err = ioutil.TempFile("", "test_auth_emails_") | 	vt.authEmailFile, err = ioutil.TempFile("", "test_auth_emails_") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal("failed to create temp file for copy: " + err.Error()) | 		t.Fatal("failed to create temp file for copy: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| 	vt.WriteEmails(t, emails) | 	vt.WriteEmails(t, emails) | ||||||
| 	err = os.Rename(vt.auth_email_file.Name(), orig_file.Name()) | 	err = os.Rename(vt.authEmailFile.Name(), origFile.Name()) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal("failed to copy over temp file: " + err.Error()) | 		t.Fatal("failed to copy over temp file: " + err.Error()) | ||||||
| 	} | 	} | ||||||
| 	vt.auth_email_file = orig_file | 	vt.authEmailFile = origFile | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestValidatorOverwriteEmailListViaCopyingOver(t *testing.T) { | func TestValidatorOverwriteEmailListViaCopyingOver(t *testing.T) { | ||||||
|  |  | ||||||
|  | @ -10,8 +10,8 @@ import ( | ||||||
| 
 | 
 | ||||||
| func (vt *ValidatorTest) UpdateEmailFile(t *testing.T, emails []string) { | func (vt *ValidatorTest) UpdateEmailFile(t *testing.T, emails []string) { | ||||||
| 	var err error | 	var err error | ||||||
| 	vt.auth_email_file, err = os.OpenFile( | 	vt.authEmailFile, err = os.OpenFile( | ||||||
| 		vt.auth_email_file.Name(), os.O_WRONLY|os.O_CREATE, 0600) | 		vt.authEmailFile.Name(), os.O_WRONLY|os.O_CREATE, 0600) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal("failed to re-open temp file for updates") | 		t.Fatal("failed to re-open temp file for updates") | ||||||
| 	} | 	} | ||||||
|  | @ -20,24 +20,24 @@ func (vt *ValidatorTest) UpdateEmailFile(t *testing.T, emails []string) { | ||||||
| 
 | 
 | ||||||
| func (vt *ValidatorTest) UpdateEmailFileViaRenameAndReplace( | func (vt *ValidatorTest) UpdateEmailFileViaRenameAndReplace( | ||||||
| 	t *testing.T, emails []string) { | 	t *testing.T, emails []string) { | ||||||
| 	orig_file := vt.auth_email_file | 	origFile := vt.authEmailFile | ||||||
| 	var err error | 	var err error | ||||||
| 	vt.auth_email_file, err = ioutil.TempFile("", "test_auth_emails_") | 	vt.authEmailFile, err = ioutil.TempFile("", "test_auth_emails_") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal("failed to create temp file for rename and replace: " + | 		t.Fatal("failed to create temp file for rename and replace: " + | ||||||
| 			err.Error()) | 			err.Error()) | ||||||
| 	} | 	} | ||||||
| 	vt.WriteEmails(t, emails) | 	vt.WriteEmails(t, emails) | ||||||
| 
 | 
 | ||||||
| 	moved_name := orig_file.Name() + "-moved" | 	movedName := origFile.Name() + "-moved" | ||||||
| 	err = os.Rename(orig_file.Name(), moved_name) | 	err = os.Rename(origFile.Name(), movedName) | ||||||
| 	err = os.Rename(vt.auth_email_file.Name(), orig_file.Name()) | 	err = os.Rename(vt.authEmailFile.Name(), origFile.Name()) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal("failed to rename and replace temp file: " + | 		t.Fatal("failed to rename and replace temp file: " + | ||||||
| 			err.Error()) | 			err.Error()) | ||||||
| 	} | 	} | ||||||
| 	vt.auth_email_file = orig_file | 	vt.authEmailFile = origFile | ||||||
| 	os.Remove(moved_name) | 	os.Remove(movedName) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestValidatorOverwriteEmailListDirectly(t *testing.T) { | func TestValidatorOverwriteEmailListDirectly(t *testing.T) { | ||||||
|  |  | ||||||
|  | @ -1,3 +1,4 @@ | ||||||
| package main | package main | ||||||
| 
 | 
 | ||||||
| const VERSION = "2.2.1-alpha" | // VERSION contains version information
 | ||||||
|  | var VERSION = "undefined" | ||||||
|  |  | ||||||
							
								
								
									
										13
									
								
								watcher.go
								
								
								
								
							
							
						
						
									
										13
									
								
								watcher.go
								
								
								
								
							|  | @ -8,16 +8,18 @@ import ( | ||||||
| 	"path/filepath" | 	"path/filepath" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"gopkg.in/fsnotify.v1" | 	fsnotify "gopkg.in/fsnotify/fsnotify.v1" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | // WaitForReplacement waits for a file to exist on disk and then starts a watch
 | ||||||
|  | // for the file
 | ||||||
| func WaitForReplacement(filename string, op fsnotify.Op, | func WaitForReplacement(filename string, op fsnotify.Op, | ||||||
| 	watcher *fsnotify.Watcher) { | 	watcher *fsnotify.Watcher) { | ||||||
| 	const sleep_interval = 50 * time.Millisecond | 	const sleepInterval = 50 * time.Millisecond | ||||||
| 
 | 
 | ||||||
| 	// Avoid a race when fsnofity.Remove is preceded by fsnotify.Chmod.
 | 	// Avoid a race when fsnofity.Remove is preceded by fsnotify.Chmod.
 | ||||||
| 	if op&fsnotify.Chmod != 0 { | 	if op&fsnotify.Chmod != 0 { | ||||||
| 		time.Sleep(sleep_interval) | 		time.Sleep(sleepInterval) | ||||||
| 	} | 	} | ||||||
| 	for { | 	for { | ||||||
| 		if _, err := os.Stat(filename); err == nil { | 		if _, err := os.Stat(filename); err == nil { | ||||||
|  | @ -26,10 +28,11 @@ func WaitForReplacement(filename string, op fsnotify.Op, | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 		time.Sleep(sleep_interval) | 		time.Sleep(sleepInterval) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // WatchForUpdates performs an action every time a file on disk is updated
 | ||||||
| func WatchForUpdates(filename string, done <-chan bool, action func()) { | func WatchForUpdates(filename string, done <-chan bool, action func()) { | ||||||
| 	filename = filepath.Clean(filename) | 	filename = filepath.Clean(filename) | ||||||
| 	watcher, err := fsnotify.NewWatcher() | 	watcher, err := fsnotify.NewWatcher() | ||||||
|  | @ -56,7 +59,7 @@ func WatchForUpdates(filename string, done <-chan bool, action func()) { | ||||||
| 				} | 				} | ||||||
| 				log.Printf("reloading after event: %s", event) | 				log.Printf("reloading after event: %s", event) | ||||||
| 				action() | 				action() | ||||||
| 			case err := <-watcher.Errors: | 			case err = <-watcher.Errors: | ||||||
| 				log.Printf("error watching %s: %s", filename, err) | 				log.Printf("error watching %s: %s", filename, err) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue