Merge remote-tracking branch 'origin2/master'
This commit is contained in:
commit
e7919f0535
35
CHANGELOG.md
35
CHANGELOG.md
|
|
@ -4,7 +4,21 @@
|
||||||
|
|
||||||
## Important Notes
|
## Important Notes
|
||||||
|
|
||||||
|
- [#964](https://github.com/oauth2-proxy/oauth2-proxy/pull/964) Redirect URL generation will attempt secondary strategies
|
||||||
|
in the priority chain if any fail the `IsValidRedirect` security check. Previously any failures fell back to `/`.
|
||||||
|
- [#953](https://github.com/oauth2-proxy/oauth2-proxy/pull/953) Keycloak will now use `--profile-url` if set for the userinfo endpoint
|
||||||
|
instead of `--validate-url`. `--validate-url` will still work for backwards compatibility.
|
||||||
|
- [#957](https://github.com/oauth2-proxy/oauth2-proxy/pull/957) To use X-Forwarded-{Proto,Host,Uri} on redirect detection, `--reverse-proxy` must be `true`.
|
||||||
|
- [#936](https://github.com/oauth2-proxy/oauth2-proxy/pull/936) `--user-id-claim` option is deprecated and replaced by `--oidc-email-claim`
|
||||||
- [#630](https://github.com/oauth2-proxy/oauth2-proxy/pull/630) Gitlab projects needs a Gitlab application with the extra `read_api` enabled
|
- [#630](https://github.com/oauth2-proxy/oauth2-proxy/pull/630) Gitlab projects needs a Gitlab application with the extra `read_api` enabled
|
||||||
|
- [#849](https://github.com/oauth2-proxy/oauth2-proxy/pull/849) `/oauth2/auth` `allowed_groups` querystring parameter can be paired with the `allowed-groups` configuration option.
|
||||||
|
- The `allowed_groups` querystring parameter can specify multiple comma delimited groups.
|
||||||
|
- In this scenario, the user must have a group (from their multiple groups) present in both lists to not get a 401 or 403 response code.
|
||||||
|
- Example:
|
||||||
|
- OAuth2-Proxy globally sets the `allowed_groups` as `engineering`.
|
||||||
|
- An application using Kubernetes ingress uses the `/oauth2/auth` endpoint with `allowed_groups` querystring set to `backend`.
|
||||||
|
- A user must have a session with the groups `["engineering", "backend"]` to pass authorization.
|
||||||
|
- Another user with the groups `["engineering", "frontend"]` would fail the querystring authorization portion.
|
||||||
- [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Existing sessions from v6.0.0 or earlier are no longer valid. They will trigger a reauthentication.
|
- [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Existing sessions from v6.0.0 or earlier are no longer valid. They will trigger a reauthentication.
|
||||||
- [#826](https://github.com/oauth2-proxy/oauth2-proxy/pull/826) `skip-auth-strip-headers` now applies to all requests, not just those where authentication would be skipped.
|
- [#826](https://github.com/oauth2-proxy/oauth2-proxy/pull/826) `skip-auth-strip-headers` now applies to all requests, not just those where authentication would be skipped.
|
||||||
- [#797](https://github.com/oauth2-proxy/oauth2-proxy/pull/797) The behavior of the Google provider Groups restriction changes with this
|
- [#797](https://github.com/oauth2-proxy/oauth2-proxy/pull/797) The behavior of the Google provider Groups restriction changes with this
|
||||||
|
|
@ -18,11 +32,19 @@
|
||||||
- [#575](https://github.com/oauth2-proxy/oauth2-proxy/pull/575) Sessions from v5.1.1 or earlier will no longer validate since they were not signed with SHA1.
|
- [#575](https://github.com/oauth2-proxy/oauth2-proxy/pull/575) Sessions from v5.1.1 or earlier will no longer validate since they were not signed with SHA1.
|
||||||
- Sessions from v6.0.0 or later had a graceful conversion to SHA256 that resulted in no reauthentication
|
- Sessions from v6.0.0 or later had a graceful conversion to SHA256 that resulted in no reauthentication
|
||||||
- Upgrading from v5.1.1 or earlier will result in a reauthentication
|
- Upgrading from v5.1.1 or earlier will result in a reauthentication
|
||||||
- [#616](https://github.com/oauth2-proxy/oauth2-proxy/pull/616) Ensure you have configured oauth2-proxy to use the `groups` scope. The user may be logged out initially as they may not currently have the `groups` claim however after going back through login process wil be authenticated.
|
- [#616](https://github.com/oauth2-proxy/oauth2-proxy/pull/616) Ensure you have configured oauth2-proxy to use the `groups` scope.
|
||||||
|
- The user may be logged out initially as they may not currently have the `groups` claim however after going back through login process wil be authenticated.
|
||||||
- [#839](https://github.com/oauth2-proxy/oauth2-proxy/pull/839) Enables complex data structures for group claim entries, which are output as Json by default.
|
- [#839](https://github.com/oauth2-proxy/oauth2-proxy/pull/839) Enables complex data structures for group claim entries, which are output as Json by default.
|
||||||
|
|
||||||
## Breaking Changes
|
## Breaking Changes
|
||||||
|
|
||||||
|
- [#964](https://github.com/oauth2-proxy/oauth2-proxy/pull/964) `--reverse-proxy` must be true to trust `X-Forwarded-*` headers as canonical.
|
||||||
|
These are used throughout the application in redirect URLs, cookie domains and host logging logic. These are the headers:
|
||||||
|
- `X-Forwarded-Proto` instead of `req.URL.Scheme`
|
||||||
|
- `X-Forwarded-Host` instead of `req.Host`
|
||||||
|
- `X-Forwarded-Uri` instead of `req.URL.RequestURI()`
|
||||||
|
- [#953](https://github.com/oauth2-proxy/oauth2-proxy/pull/953) In config files & envvar configs, `keycloak_group` is now the plural `keycloak_groups`.
|
||||||
|
Flag configs are still `--keycloak-group` but it can be passed multiple times.
|
||||||
- [#911](https://github.com/oauth2-proxy/oauth2-proxy/pull/911) Specifying a non-existent provider will cause OAuth2-Proxy to fail on startup instead of defaulting to "google".
|
- [#911](https://github.com/oauth2-proxy/oauth2-proxy/pull/911) Specifying a non-existent provider will cause OAuth2-Proxy to fail on startup instead of defaulting to "google".
|
||||||
- [#797](https://github.com/oauth2-proxy/oauth2-proxy/pull/797) Security changes to Google provider group authorization flow
|
- [#797](https://github.com/oauth2-proxy/oauth2-proxy/pull/797) Security changes to Google provider group authorization flow
|
||||||
- If you change the list of allowed groups, existing sessions that now don't have a valid group will be logged out immediately.
|
- If you change the list of allowed groups, existing sessions that now don't have a valid group will be logged out immediately.
|
||||||
|
|
@ -44,15 +66,23 @@
|
||||||
|
|
||||||
## Changes since v6.1.1
|
## Changes since v6.1.1
|
||||||
|
|
||||||
|
- [#995](https://github.com/oauth2-proxy/oauth2-proxy/pull/995) Add Security Policy (@JoelSpeed)
|
||||||
|
- [#964](https://github.com/oauth2-proxy/oauth2-proxy/pull/964) Require `--reverse-proxy` true to trust `X-Forwareded-*` type headers (@NickMeves)
|
||||||
|
- [#970](https://github.com/oauth2-proxy/oauth2-proxy/pull/970) Fix joined cookie name for those containing underline in the suffix (@peppered)
|
||||||
|
- [#953](https://github.com/oauth2-proxy/oauth2-proxy/pull/953) Migrate Keycloak to EnrichSession & support multiple groups for authorization (@NickMeves)
|
||||||
|
- [#957](https://github.com/oauth2-proxy/oauth2-proxy/pull/957) Use X-Forwarded-{Proto,Host,Uri} on redirect as last resort (@linuxgemini)
|
||||||
- [#630](https://github.com/oauth2-proxy/oauth2-proxy/pull/630) Add support for Gitlab project based authentication (@factorysh)
|
- [#630](https://github.com/oauth2-proxy/oauth2-proxy/pull/630) Add support for Gitlab project based authentication (@factorysh)
|
||||||
- [#907](https://github.com/oauth2-proxy/oauth2-proxy/pull/907) Introduce alpha configuration option to enable testing of structured configuration (@JoelSpeed)
|
- [#907](https://github.com/oauth2-proxy/oauth2-proxy/pull/907) Introduce alpha configuration option to enable testing of structured configuration (@JoelSpeed)
|
||||||
- [#938](https://github.com/oauth2-proxy/oauth2-proxy/pull/938) Cleanup missed provider renaming refactor methods (@NickMeves)
|
- [#938](https://github.com/oauth2-proxy/oauth2-proxy/pull/938) Cleanup missed provider renaming refactor methods (@NickMeves)
|
||||||
|
- [#816](https://github.com/oauth2-proxy/oauth2-proxy/pull/816) (via [#936](https://github.com/oauth2-proxy/oauth2-proxy/pull/936)) Support non-list group claims (@loafoe)
|
||||||
|
- [#936](https://github.com/oauth2-proxy/oauth2-proxy/pull/936) Refactor OIDC Provider and support groups from Profile URL (@NickMeves)
|
||||||
|
- [#869](https://github.com/oauth2-proxy/oauth2-proxy/pull/869) Streamline provider interface method names and signatures (@NickMeves)
|
||||||
|
- [#849](https://github.com/oauth2-proxy/oauth2-proxy/pull/849) Support group authorization on `oauth2/auth` endpoint via `allowed_groups` querystring (@NickMeves)
|
||||||
- [#925](https://github.com/oauth2-proxy/oauth2-proxy/pull/925) Fix basic auth legacy header conversion (@JoelSpeed)
|
- [#925](https://github.com/oauth2-proxy/oauth2-proxy/pull/925) Fix basic auth legacy header conversion (@JoelSpeed)
|
||||||
- [#916](https://github.com/oauth2-proxy/oauth2-proxy/pull/916) Add AlphaOptions struct to prepare for alpha config loading (@JoelSpeed)
|
- [#916](https://github.com/oauth2-proxy/oauth2-proxy/pull/916) Add AlphaOptions struct to prepare for alpha config loading (@JoelSpeed)
|
||||||
- [#923](https://github.com/oauth2-proxy/oauth2-proxy/pull/923) Support TLS 1.3 (@aajisaka)
|
- [#923](https://github.com/oauth2-proxy/oauth2-proxy/pull/923) Support TLS 1.3 (@aajisaka)
|
||||||
- [#918](https://github.com/oauth2-proxy/oauth2-proxy/pull/918) Fix log header output (@JoelSpeed)
|
- [#918](https://github.com/oauth2-proxy/oauth2-proxy/pull/918) Fix log header output (@JoelSpeed)
|
||||||
- [#911](https://github.com/oauth2-proxy/oauth2-proxy/pull/911) Validate provider type on startup.
|
- [#911](https://github.com/oauth2-proxy/oauth2-proxy/pull/911) Validate provider type on startup.
|
||||||
- [#869](https://github.com/oauth2-proxy/oauth2-proxy/pull/869) Streamline provider interface method names and signatures (@NickMeves)
|
|
||||||
- [#906](https://github.com/oauth2-proxy/oauth2-proxy/pull/906) Set up v6.1.x versioned documentation as default documentation (@JoelSpeed)
|
- [#906](https://github.com/oauth2-proxy/oauth2-proxy/pull/906) Set up v6.1.x versioned documentation as default documentation (@JoelSpeed)
|
||||||
- [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Remove v5 legacy sessions support (@NickMeves)
|
- [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Remove v5 legacy sessions support (@NickMeves)
|
||||||
- [#904](https://github.com/oauth2-proxy/oauth2-proxy/pull/904) Set `skip-auth-strip-headers` to `true` by default (@NickMeves)
|
- [#904](https://github.com/oauth2-proxy/oauth2-proxy/pull/904) Set `skip-auth-strip-headers` to `true` by default (@NickMeves)
|
||||||
|
|
@ -79,6 +109,7 @@
|
||||||
- [#750](https://github.com/oauth2-proxy/oauth2-proxy/pull/750) ci: Migrate to Github Actions (@shinebayar-g)
|
- [#750](https://github.com/oauth2-proxy/oauth2-proxy/pull/750) ci: Migrate to Github Actions (@shinebayar-g)
|
||||||
- [#829](https://github.com/oauth2-proxy/oauth2-proxy/pull/820) Rename test directory to testdata (@johejo)
|
- [#829](https://github.com/oauth2-proxy/oauth2-proxy/pull/820) Rename test directory to testdata (@johejo)
|
||||||
- [#819](https://github.com/oauth2-proxy/oauth2-proxy/pull/819) Improve CI (@johejo)
|
- [#819](https://github.com/oauth2-proxy/oauth2-proxy/pull/819) Improve CI (@johejo)
|
||||||
|
- [#989](https://github.com/oauth2-proxy/oauth2-proxy/pull/989) Adapt isAjax to support mimetype lists (@rassie)
|
||||||
|
|
||||||
# v6.1.1
|
# v6.1.1
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,2 +1,3 @@
|
||||||
Joel Speed <joel.speed@hotmail.co.uk> (@JoelSpeed)
|
Joel Speed <joel.speed@hotmail.co.uk> (@JoelSpeed)
|
||||||
Henry Jenkins <henry@henryjenkins.name> (@steakunderscore)
|
Henry Jenkins <henry@henryjenkins.name> (@steakunderscore)
|
||||||
|
Nick Meves <nick.meves@greenhouse.io> (@NickMeves)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
# Security Disclosures
|
||||||
|
|
||||||
|
Please see [our community docs](https://oauth2-proxy.github.io/oauth2-proxy/docs/community/security) for our security policy.
|
||||||
|
|
@ -0,0 +1,49 @@
|
||||||
|
---
|
||||||
|
id: security
|
||||||
|
title: Security
|
||||||
|
---
|
||||||
|
|
||||||
|
:::note
|
||||||
|
OAuth2 Proxy is a community project.
|
||||||
|
Maintainers do not work on this project full time, and as such,
|
||||||
|
while we endeavour to respond to disclosures as quickly as possible,
|
||||||
|
this may take longer than in projects with corporate sponsorship.
|
||||||
|
:::
|
||||||
|
|
||||||
|
## Security Disclosures
|
||||||
|
|
||||||
|
:::important
|
||||||
|
If you believe you have found a vulnerability within OAuth2 Proxy or any of its
|
||||||
|
dependencies, please do NOT open an issue or PR on GitHub, please do NOT post
|
||||||
|
any details publicly.
|
||||||
|
:::
|
||||||
|
|
||||||
|
Security disclosures MUST be done in private.
|
||||||
|
If you have found an issue that you would like to bring to the attention of the
|
||||||
|
maintenance team for OAuth2 Proxy, please compose an email and send it to the
|
||||||
|
list of maintainers in our [MAINTAINERS](https://github.com/oauth2-proxy/oauth2-proxy/blob/master/MAINTAINERS) file.
|
||||||
|
|
||||||
|
Please include as much detail as possible.
|
||||||
|
Ideally, your disclosure should include:
|
||||||
|
- A reproducible case that can be used to demonstrate the exploit
|
||||||
|
- How you discovered this vulnerability
|
||||||
|
- A potential fix for the issue (if you have thought of one)
|
||||||
|
- Versions affected (if not present in master)
|
||||||
|
- Your GitHub ID
|
||||||
|
|
||||||
|
### How will we respond to disclosures?
|
||||||
|
|
||||||
|
We use [GitHub Security Advisories](https://docs.github.com/en/github/managing-security-vulnerabilities/about-github-security-advisories)
|
||||||
|
to privately discuss fixes for disclosed vulnerabilities.
|
||||||
|
If you include a GitHub ID with your disclosure we will add you as a collaborator
|
||||||
|
for the advisory so that you can join the discussion and validate any fixes
|
||||||
|
we may propose.
|
||||||
|
|
||||||
|
For minor issues and previously disclosed vulnerabilities (typically for
|
||||||
|
dependencies), we may use regular PRs for fixes and forego the security advisory.
|
||||||
|
|
||||||
|
Once a fix has been agreed upon, we will merge the fix and create a new release.
|
||||||
|
If we have multiple security issues in flight simultaneously, we may delay
|
||||||
|
merging fixes until all patches are ready.
|
||||||
|
We may also backport the fix to previous releases,
|
||||||
|
but this will be at the discretion of the maintainers.
|
||||||
|
|
@ -135,15 +135,25 @@ If you are using GitHub enterprise, make sure you set the following to the appro
|
||||||
|
|
||||||
Make sure you set the following to the appropriate url:
|
Make sure you set the following to the appropriate url:
|
||||||
|
|
||||||
-provider=keycloak
|
--provider=keycloak
|
||||||
-client-id=<client you have created>
|
--client-id=<client you have created>
|
||||||
-client-secret=<your client's secret>
|
--client-secret=<your client's secret>
|
||||||
-login-url="http(s)://<keycloak host>/auth/realms/<your realm>/protocol/openid-connect/auth"
|
--login-url="http(s)://<keycloak host>/auth/realms/<your realm>/protocol/openid-connect/auth"
|
||||||
-redeem-url="http(s)://<keycloak host>/auth/realms/<your realm>/protocol/openid-connect/token"
|
--redeem-url="http(s)://<keycloak host>/auth/realms/<your realm>/protocol/openid-connect/token"
|
||||||
-validate-url="http(s)://<keycloak host>/auth/realms/<your realm>/protocol/openid-connect/userinfo"
|
--profile-url="http(s)://<keycloak host>/auth/realms/<your realm>/protocol/openid-connect/userinfo"
|
||||||
-keycloak-group=<user_group>
|
--validate-url="http(s)://<keycloak host>/auth/realms/<your realm>/protocol/openid-connect/userinfo"
|
||||||
|
--keycloak-group=<first_allowed_user_group>
|
||||||
|
--keycloak-group=<second_allowed_user_group>
|
||||||
|
|
||||||
The group management in keycloak is using a tree. If you create a group named admin in keycloak you should define the 'keycloak-group' value to /admin.
|
For group based authorization, the optional `--keycloak-group` (legacy) or `--allowed-group` (global standard)
|
||||||
|
flags can be used to specify which groups to limit access to.
|
||||||
|
|
||||||
|
If these are unset but a `groups` mapper is set up above in step (3), the provider will still
|
||||||
|
populate the `X-Forwarded-Groups` header to your upstream server with the `groups` data in the
|
||||||
|
Keycloak userinfo endpoint response.
|
||||||
|
|
||||||
|
The group management in keycloak is using a tree. If you create a group named admin in keycloak
|
||||||
|
you should define the 'keycloak-group' value to /admin.
|
||||||
|
|
||||||
### GitLab Auth Provider
|
### GitLab Auth Provider
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -74,7 +74,8 @@ An example [oauth2-proxy.cfg](https://github.com/oauth2-proxy/oauth2-proxy/blob/
|
||||||
| `--insecure-oidc-skip-issuer-verification` | bool | allow the OIDC issuer URL to differ from the expected (currently required for Azure multi-tenant compatibility) | false |
|
| `--insecure-oidc-skip-issuer-verification` | bool | allow the OIDC issuer URL to differ from the expected (currently required for Azure multi-tenant compatibility) | false |
|
||||||
| `--oidc-issuer-url` | string | the OpenID Connect issuer URL, e.g. `"https://accounts.google.com"` | |
|
| `--oidc-issuer-url` | string | the OpenID Connect issuer URL, e.g. `"https://accounts.google.com"` | |
|
||||||
| `--oidc-jwks-url` | string | OIDC JWKS URI for token verification; required if OIDC discovery is disabled | |
|
| `--oidc-jwks-url` | string | OIDC JWKS URI for token verification; required if OIDC discovery is disabled | |
|
||||||
| `--oidc-groups-claim` | string | which claim contains the user groups | `"groups"` |
|
| `--oidc-email-claim` | string | which OIDC claim contains the user's email | `"email"` |
|
||||||
|
| `--oidc-groups-claim` | string | which OIDC claim contains the user groups | `"groups"` |
|
||||||
| `--pass-access-token` | bool | pass OAuth access_token to upstream via X-Forwarded-Access-Token header. When used with `--set-xauthrequest` this adds the X-Auth-Request-Access-Token header to the response | false |
|
| `--pass-access-token` | bool | pass OAuth access_token to upstream via X-Forwarded-Access-Token header. When used with `--set-xauthrequest` this adds the X-Auth-Request-Access-Token header to the response | false |
|
||||||
| `--pass-authorization-header` | bool | pass OIDC IDToken to upstream via Authorization Bearer header | false |
|
| `--pass-authorization-header` | bool | pass OIDC IDToken to upstream via Authorization Bearer header | false |
|
||||||
| `--pass-basic-auth` | bool | pass HTTP Basic Auth, X-Forwarded-User, X-Forwarded-Email and X-Forwarded-Preferred-Username information to upstream | true |
|
| `--pass-basic-auth` | bool | pass HTTP Basic Auth, X-Forwarded-User, X-Forwarded-Email and X-Forwarded-Preferred-Username information to upstream | true |
|
||||||
|
|
@ -105,7 +106,7 @@ An example [oauth2-proxy.cfg](https://github.com/oauth2-proxy/oauth2-proxy/blob/
|
||||||
| `--request-logging` | bool | Log requests | true |
|
| `--request-logging` | bool | Log requests | true |
|
||||||
| `--request-logging-format` | string | Template for request log lines | see [Logging Configuration](#logging-configuration) |
|
| `--request-logging-format` | string | Template for request log lines | see [Logging Configuration](#logging-configuration) |
|
||||||
| `--resource` | string | The resource that is protected (Azure AD only) | |
|
| `--resource` | string | The resource that is protected (Azure AD only) | |
|
||||||
| `--reverse-proxy` | bool | are we running behind a reverse proxy, controls whether headers like X-Real-IP are accepted | false |
|
| `--reverse-proxy` | bool | are we running behind a reverse proxy, controls whether headers like X-Real-IP are accepted and allows X-Forwarded-{Proto,Host,Uri} headers to be used on redirect selection | false |
|
||||||
| `--scope` | string | OAuth scope specification | |
|
| `--scope` | string | OAuth scope specification | |
|
||||||
| `--session-cookie-minimal` | bool | strip OAuth tokens from cookie session stores if they aren't needed (cookie session store only) | false |
|
| `--session-cookie-minimal` | bool | strip OAuth tokens from cookie session stores if they aren't needed (cookie session store only) | false |
|
||||||
| `--session-store-type` | string | [Session data storage backend](sessions.md); redis or cookie | cookie |
|
| `--session-store-type` | string | [Session data storage backend](sessions.md); redis or cookie | cookie |
|
||||||
|
|
@ -128,7 +129,6 @@ An example [oauth2-proxy.cfg](https://github.com/oauth2-proxy/oauth2-proxy/blob/
|
||||||
| `--tls-cert-file` | string | path to certificate file | |
|
| `--tls-cert-file` | string | path to certificate file | |
|
||||||
| `--tls-key-file` | string | path to private key file | |
|
| `--tls-key-file` | string | path to private key file | |
|
||||||
| `--upstream` | string \| list | the http url(s) of the upstream endpoint, file:// paths for static files or `static://<status_code>` for static response. Routing is based on the path | |
|
| `--upstream` | string \| list | the http url(s) of the upstream endpoint, file:// paths for static files or `static://<status_code>` for static response. Routing is based on the path | |
|
||||||
| `--user-id-claim` | string | which claim contains the user ID | \["email"\] |
|
|
||||||
| `--allowed-group` | string \| list | restrict logins to members of this group (may be given multiple times) | |
|
| `--allowed-group` | string \| list | restrict logins to members of this group (may be given multiple times) | |
|
||||||
| `--validate-url` | string | Access token validation endpoint | |
|
| `--validate-url` | string | Access token validation endpoint | |
|
||||||
| `--version` | n/a | print version string | |
|
| `--version` | n/a | print version string | |
|
||||||
|
|
@ -354,6 +354,73 @@ It is recommended to use `--session-store-type=redis` when expecting large sessi
|
||||||
|
|
||||||
You have to substitute *name* with the actual cookie name you configured via --cookie-name parameter. If you don't set a custom cookie name the variable should be "$upstream_cookie__oauth2_proxy_1" instead of "$upstream_cookie_name_1" and the new cookie-name should be "_oauth2_proxy_1=" instead of "name_1=".
|
You have to substitute *name* with the actual cookie name you configured via --cookie-name parameter. If you don't set a custom cookie name the variable should be "$upstream_cookie__oauth2_proxy_1" instead of "$upstream_cookie_name_1" and the new cookie-name should be "_oauth2_proxy_1=" instead of "name_1=".
|
||||||
|
|
||||||
|
## Configuring for use with the Traefik (v2) `ForwardAuth` middleware
|
||||||
|
|
||||||
|
**This option requires `--reverse-proxy` option to be set.**
|
||||||
|
|
||||||
|
The [Traefik v2 `ForwardAuth` middleware](https://doc.traefik.io/traefik/middlewares/forwardauth/) allows Traefik to authenticate requests via the oauth2-proxy's `/oauth2/auth` endpoint on every request, which only returns a 202 Accepted response or a 401 Unauthorized response without proxying the whole request through. For example, on Dynamic File (YAML) Configuration:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
http:
|
||||||
|
routers:
|
||||||
|
a-service:
|
||||||
|
rule: "Host(`a-service.example.com`)"
|
||||||
|
service: a-service-backend
|
||||||
|
middlewares:
|
||||||
|
- oauth-errors
|
||||||
|
- oauth-auth
|
||||||
|
tls:
|
||||||
|
certResolver: default
|
||||||
|
domains:
|
||||||
|
- main: "example.com"
|
||||||
|
sans:
|
||||||
|
- "*.example.com"
|
||||||
|
oauth:
|
||||||
|
rule: "Host(`a-service.example.com`, `oauth.example.com`) && PathPrefix(`/oauth2/`)"
|
||||||
|
middlewares:
|
||||||
|
- auth-headers
|
||||||
|
service: oauth-backend
|
||||||
|
tls:
|
||||||
|
certResolver: default
|
||||||
|
domains:
|
||||||
|
- main: "example.com"
|
||||||
|
sans:
|
||||||
|
- "*.example.com"
|
||||||
|
|
||||||
|
services:
|
||||||
|
a-service-backend:
|
||||||
|
loadBalancer:
|
||||||
|
servers:
|
||||||
|
- url: http://172.16.0.2:7555
|
||||||
|
oauth-backend:
|
||||||
|
loadBalancer:
|
||||||
|
servers:
|
||||||
|
- url: http://172.16.0.1:4180
|
||||||
|
|
||||||
|
middlewares:
|
||||||
|
auth-headers:
|
||||||
|
headers:
|
||||||
|
sslRedirect: true
|
||||||
|
stsSeconds: 315360000
|
||||||
|
browserXssFilter: true
|
||||||
|
contentTypeNosniff: true
|
||||||
|
forceSTSHeader: true
|
||||||
|
sslHost: example.com
|
||||||
|
stsIncludeSubdomains: true
|
||||||
|
stsPreload: true
|
||||||
|
frameDeny: true
|
||||||
|
oauth-auth:
|
||||||
|
forwardAuth:
|
||||||
|
address: https://oauth.example.com/oauth2/auth
|
||||||
|
trustForwardHeader: true
|
||||||
|
oauth-errors:
|
||||||
|
errors:
|
||||||
|
status:
|
||||||
|
- "401-403"
|
||||||
|
service: oauth-backend
|
||||||
|
query: "/oauth2/sign_in"
|
||||||
|
```
|
||||||
|
|
||||||
:::note
|
:::note
|
||||||
If you set up your OAuth2 provider to rotate your client secret, you can use the `client-secret-file` option to reload the secret when it is updated.
|
If you set up your OAuth2 provider to rotate your client secret, you can use the `client-secret-file` option to reload the secret when it is updated.
|
||||||
:::
|
:::
|
||||||
|
|
|
||||||
|
|
@ -20,5 +20,11 @@ module.exports = {
|
||||||
collapsed: false,
|
collapsed: false,
|
||||||
items: ['features/endpoints', 'features/request_signatures'],
|
items: ['features/endpoints', 'features/request_signatures'],
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
type: 'category',
|
||||||
|
label: 'Community',
|
||||||
|
collapsed: false,
|
||||||
|
items: ['community/security'],
|
||||||
|
},
|
||||||
],
|
],
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,49 @@
|
||||||
|
---
|
||||||
|
id: security
|
||||||
|
title: Security
|
||||||
|
---
|
||||||
|
|
||||||
|
:::note
|
||||||
|
OAuth2 Proxy is a community project.
|
||||||
|
Maintainers do not work on this project full time, and as such,
|
||||||
|
while we endeavour to respond to disclosures as quickly as possible,
|
||||||
|
this may take longer than in projects with corporate sponsorship.
|
||||||
|
:::
|
||||||
|
|
||||||
|
## Security Disclosures
|
||||||
|
|
||||||
|
:::important
|
||||||
|
If you believe you have found a vulnerability within OAuth2 Proxy or any of its
|
||||||
|
dependencies, please do NOT open an issue or PR on GitHub, please do NOT post any
|
||||||
|
details publicly.
|
||||||
|
:::
|
||||||
|
|
||||||
|
Security disclosures MUST be done in private.
|
||||||
|
If you have found an issue that you would like to bring to the attention of the
|
||||||
|
maintenance team for OAuth2 Proxy, please compose an email and send it to the
|
||||||
|
list of maintainers in our [MAINTAINERS](https://github.com/oauth2-proxy/oauth2-proxy/blob/master/MAINTAINERS) file.
|
||||||
|
|
||||||
|
Please include as much detail as possible.
|
||||||
|
Ideally, your disclosure should include:
|
||||||
|
- A reproducible case that can be used to demonstrate the exploit
|
||||||
|
- How you discovered this vulnerability
|
||||||
|
- A potential fix for the issue (if you have thought of one)
|
||||||
|
- Versions affected (if not present in master)
|
||||||
|
- Your GitHub ID
|
||||||
|
|
||||||
|
### How will we respond to disclosures?
|
||||||
|
|
||||||
|
We use [GitHub Security Advisories](https://docs.github.com/en/github/managing-security-vulnerabilities/about-github-security-advisories)
|
||||||
|
to privately discuss fixes for disclosed vulnerabilities.
|
||||||
|
If you include a GitHub ID with your disclosure we will add you as a collaborator
|
||||||
|
for the advisory so that you can join the discussion and validate any fixes
|
||||||
|
we may propose.
|
||||||
|
|
||||||
|
For minor issues and previously disclosed vulnerabilities (typically for
|
||||||
|
dependencies), we may use regular PRs for fixes and forego the security advisory.
|
||||||
|
|
||||||
|
Once a fix has been agreed upon, we will merge the fix and create a new release.
|
||||||
|
If we have multiple security issues in flight simultaneously, we may delay
|
||||||
|
merging fixes until all patches are ready.
|
||||||
|
We may also backport the fix to previous releases,
|
||||||
|
but this will be at the discretion of the maintainers.
|
||||||
|
|
@ -45,6 +45,17 @@
|
||||||
"id": "version-6.1.x/features/request_signatures"
|
"id": "version-6.1.x/features/request_signatures"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"collapsed": false,
|
||||||
|
"type": "category",
|
||||||
|
"label": "Community",
|
||||||
|
"items": [
|
||||||
|
{
|
||||||
|
"type": "doc",
|
||||||
|
"id": "version-6.1.x/community/security"
|
||||||
|
}
|
||||||
|
]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
|
||||||
663
oauthproxy.go
663
oauthproxy.go
|
|
@ -24,16 +24,14 @@ import (
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
|
||||||
|
requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/upstream"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/upstream"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util"
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/providers"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/providers"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
httpScheme = "http"
|
schemeHTTPS = "https"
|
||||||
httpsScheme = "https"
|
|
||||||
|
|
||||||
applicationJSON = "application/json"
|
applicationJSON = "application/json"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -229,7 +227,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
|
||||||
// the OAuth2 Proxy authentication logic kicks in.
|
// the OAuth2 Proxy authentication logic kicks in.
|
||||||
// For example forcing HTTPS or health checks.
|
// For example forcing HTTPS or health checks.
|
||||||
func buildPreAuthChain(opts *options.Options) (alice.Chain, error) {
|
func buildPreAuthChain(opts *options.Options) (alice.Chain, error) {
|
||||||
chain := alice.New(middleware.NewScope())
|
chain := alice.New(middleware.NewScope(opts.ReverseProxy))
|
||||||
|
|
||||||
if opts.ForceHTTPS {
|
if opts.ForceHTTPS {
|
||||||
_, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress)
|
_, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress)
|
||||||
|
|
@ -366,49 +364,6 @@ func buildRoutesAllowlist(opts *options.Options) ([]allowedRoute, error) {
|
||||||
return routes, nil
|
return routes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRedirectURI returns the redirectURL that the upstream OAuth Provider will
|
|
||||||
// redirect clients to once authenticated
|
|
||||||
func (p *OAuthProxy) GetRedirectURI(host string) string {
|
|
||||||
// default to the request Host if not set
|
|
||||||
if p.redirectURL.Host != "" {
|
|
||||||
return p.redirectURL.String()
|
|
||||||
}
|
|
||||||
u := *p.redirectURL
|
|
||||||
if u.Scheme == "" {
|
|
||||||
if p.CookieSecure {
|
|
||||||
u.Scheme = httpsScheme
|
|
||||||
} else {
|
|
||||||
u.Scheme = httpScheme
|
|
||||||
}
|
|
||||||
}
|
|
||||||
u.Host = host
|
|
||||||
return u.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (*sessionsapi.SessionState, error) {
|
|
||||||
if code == "" {
|
|
||||||
return nil, providers.ErrMissingCode
|
|
||||||
}
|
|
||||||
redirectURI := p.GetRedirectURI(host)
|
|
||||||
s, err := p.provider.Redeem(ctx, redirectURI, code)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return s, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *OAuthProxy) enrichSessionState(ctx context.Context, s *sessionsapi.SessionState) error {
|
|
||||||
var err error
|
|
||||||
if s.Email == "" {
|
|
||||||
s.Email, err = p.provider.GetEmailAddress(ctx, s)
|
|
||||||
if err != nil && !errors.Is(err, providers.ErrNotImplemented) {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return p.provider.EnrichSession(ctx, s)
|
|
||||||
}
|
|
||||||
|
|
||||||
// MakeCSRFCookie creates a cookie for CSRF
|
// MakeCSRFCookie creates a cookie for CSRF
|
||||||
func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||||
return p.makeCookie(req, p.CSRFCookieName, value, expiration, now)
|
return p.makeCookie(req, p.CSRFCookieName, value, expiration, now)
|
||||||
|
|
@ -418,7 +373,7 @@ func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, ex
|
||||||
cookieDomain := cookies.GetCookieDomain(req, p.CookieDomains)
|
cookieDomain := cookies.GetCookieDomain(req, p.CookieDomains)
|
||||||
|
|
||||||
if cookieDomain != "" {
|
if cookieDomain != "" {
|
||||||
domain := util.GetRequestHost(req)
|
domain := requestutil.GetRequestHost(req)
|
||||||
if h, _, err := net.SplitHostPort(domain); err == nil {
|
if h, _, err := net.SplitHostPort(domain); err == nil {
|
||||||
domain = h
|
domain = h
|
||||||
}
|
}
|
||||||
|
|
@ -466,6 +421,81 @@ func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *s
|
||||||
return p.sessionStore.Save(rw, req, s)
|
return p.sessionStore.Save(rw, req, s)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsValidRedirect checks whether the redirect URL is whitelisted
|
||||||
|
func (p *OAuthProxy) IsValidRedirect(redirect string) bool {
|
||||||
|
switch {
|
||||||
|
case redirect == "":
|
||||||
|
// The user didn't specify a redirect, should fallback to `/`
|
||||||
|
return false
|
||||||
|
case strings.HasPrefix(redirect, "/") && !strings.HasPrefix(redirect, "//") && !invalidRedirectRegex.MatchString(redirect):
|
||||||
|
return true
|
||||||
|
case strings.HasPrefix(redirect, "http://") || strings.HasPrefix(redirect, "https://"):
|
||||||
|
redirectURL, err := url.Parse(redirect)
|
||||||
|
if err != nil {
|
||||||
|
logger.Printf("Rejecting invalid redirect %q: scheme unsupported or missing", redirect)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
redirectHostname := redirectURL.Hostname()
|
||||||
|
|
||||||
|
for _, domain := range p.whitelistDomains {
|
||||||
|
domainHostname, domainPort := splitHostPort(strings.TrimLeft(domain, "."))
|
||||||
|
if domainHostname == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if (redirectHostname == domainHostname) || (strings.HasPrefix(domain, ".") && strings.HasSuffix(redirectHostname, domainHostname)) {
|
||||||
|
// the domain names match, now validate the ports
|
||||||
|
// if the whitelisted domain's port is '*', allow all ports
|
||||||
|
// if the whitelisted domain contains a specific port, only allow that port
|
||||||
|
// if the whitelisted domain doesn't contain a port at all, only allow empty redirect ports ie http and https
|
||||||
|
redirectPort := redirectURL.Port()
|
||||||
|
if (domainPort == "*") ||
|
||||||
|
(domainPort == redirectPort) ||
|
||||||
|
(domainPort == "" && redirectPort == "") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Printf("Rejecting invalid redirect %q: domain / port not in whitelist", redirect)
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
logger.Printf("Rejecting invalid redirect %q: not an absolute or relative URL", redirect)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
p.preAuthChain.Then(http.HandlerFunc(p.serveHTTP)).ServeHTTP(rw, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *OAuthProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
if req.URL.Path != p.AuthOnlyPath && strings.HasPrefix(req.URL.Path, p.ProxyPrefix) {
|
||||||
|
prepareNoCache(rw)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch path := req.URL.Path; {
|
||||||
|
case path == p.RobotsPath:
|
||||||
|
p.RobotsTxt(rw)
|
||||||
|
case p.IsAllowedRequest(req):
|
||||||
|
p.SkipAuthProxy(rw, req)
|
||||||
|
case path == p.SignInPath:
|
||||||
|
p.SignIn(rw, req)
|
||||||
|
case path == p.SignOutPath:
|
||||||
|
p.SignOut(rw, req)
|
||||||
|
case path == p.OAuthStartPath:
|
||||||
|
p.OAuthStart(rw, req)
|
||||||
|
case path == p.OAuthCallbackPath:
|
||||||
|
p.OAuthCallback(rw, req)
|
||||||
|
case path == p.AuthOnlyPath:
|
||||||
|
p.AuthOnly(rw, req)
|
||||||
|
case path == p.UserInfoPath:
|
||||||
|
p.UserInfo(rw, req)
|
||||||
|
default:
|
||||||
|
p.Proxy(rw, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// RobotsTxt disallows scraping pages from the OAuthProxy
|
// RobotsTxt disallows scraping pages from the OAuthProxy
|
||||||
func (p *OAuthProxy) RobotsTxt(rw http.ResponseWriter) {
|
func (p *OAuthProxy) RobotsTxt(rw http.ResponseWriter) {
|
||||||
_, err := fmt.Fprintf(rw, "User-agent: *\nDisallow: /")
|
_, err := fmt.Fprintf(rw, "User-agent: *\nDisallow: /")
|
||||||
|
|
@ -496,6 +526,42 @@ func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, m
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsAllowedRequest is used to check if auth should be skipped for this request
|
||||||
|
func (p *OAuthProxy) IsAllowedRequest(req *http.Request) bool {
|
||||||
|
isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS"
|
||||||
|
return isPreflightRequestAllowed || p.isAllowedRoute(req) || p.isTrustedIP(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAllowedRoute is used to check if the request method & path is allowed without auth
|
||||||
|
func (p *OAuthProxy) isAllowedRoute(req *http.Request) bool {
|
||||||
|
for _, route := range p.allowedRoutes {
|
||||||
|
if (route.method == "" || req.Method == route.method) && route.pathRegex.MatchString(req.URL.Path) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// isTrustedIP is used to check if a request comes from a trusted client IP address.
|
||||||
|
func (p *OAuthProxy) isTrustedIP(req *http.Request) bool {
|
||||||
|
if p.trustedIPs == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
remoteAddr, err := ip.GetClientIP(p.realClientIPParser, req)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("Error obtaining real IP for trusted IP list: %v", err)
|
||||||
|
// Possibly spoofed X-Real-IP header
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if remoteAddr == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.trustedIPs.Has(remoteAddr)
|
||||||
|
}
|
||||||
|
|
||||||
// SignInPage writes the sing in template to the response
|
// SignInPage writes the sing in template to the response
|
||||||
func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code int) {
|
func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code int) {
|
||||||
prepareNoCache(rw)
|
prepareNoCache(rw)
|
||||||
|
|
@ -507,7 +573,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code
|
||||||
}
|
}
|
||||||
rw.WriteHeader(code)
|
rw.WriteHeader(code)
|
||||||
|
|
||||||
redirectURL, err := p.GetRedirect(req)
|
redirectURL, err := p.getAppRedirect(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("Error obtaining redirect: %v", err)
|
logger.Errorf("Error obtaining redirect: %v", err)
|
||||||
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
|
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
|
||||||
|
|
@ -566,195 +632,9 @@ func (p *OAuthProxy) ManualSignIn(req *http.Request) (string, bool) {
|
||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRedirect reads the query parameter to get the URL to redirect clients to
|
|
||||||
// once authenticated with the OAuthProxy
|
|
||||||
func (p *OAuthProxy) GetRedirect(req *http.Request) (redirect string, err error) {
|
|
||||||
err = req.ParseForm()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
redirect = req.Header.Get("X-Auth-Request-Redirect")
|
|
||||||
if req.Form.Get("rd") != "" {
|
|
||||||
redirect = req.Form.Get("rd")
|
|
||||||
}
|
|
||||||
if !p.IsValidRedirect(redirect) {
|
|
||||||
// Use RequestURI to preserve ?query
|
|
||||||
redirect = req.URL.RequestURI()
|
|
||||||
if strings.HasPrefix(redirect, p.ProxyPrefix) {
|
|
||||||
redirect = "/"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// splitHostPort separates host and port. If the port is not valid, it returns
|
|
||||||
// the entire input as host, and it doesn't check the validity of the host.
|
|
||||||
// Unlike net.SplitHostPort, but per RFC 3986, it requires ports to be numeric.
|
|
||||||
// *** taken from net/url, modified validOptionalPort() to accept ":*"
|
|
||||||
func splitHostPort(hostport string) (host, port string) {
|
|
||||||
host = hostport
|
|
||||||
|
|
||||||
colon := strings.LastIndexByte(host, ':')
|
|
||||||
if colon != -1 && validOptionalPort(host[colon:]) {
|
|
||||||
host, port = host[:colon], host[colon+1:]
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
|
|
||||||
host = host[1 : len(host)-1]
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// validOptionalPort reports whether port is either an empty string
|
|
||||||
// or matches /^:\d*$/
|
|
||||||
// *** taken from net/url, modified to accept ":*"
|
|
||||||
func validOptionalPort(port string) bool {
|
|
||||||
if port == "" || port == ":*" {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if port[0] != ':' {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for _, b := range port[1:] {
|
|
||||||
if b < '0' || b > '9' {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsValidRedirect checks whether the redirect URL is whitelisted
|
|
||||||
func (p *OAuthProxy) IsValidRedirect(redirect string) bool {
|
|
||||||
switch {
|
|
||||||
case redirect == "":
|
|
||||||
// The user didn't specify a redirect, should fallback to `/`
|
|
||||||
return false
|
|
||||||
case strings.HasPrefix(redirect, "/") && !strings.HasPrefix(redirect, "//") && !invalidRedirectRegex.MatchString(redirect):
|
|
||||||
return true
|
|
||||||
case strings.HasPrefix(redirect, "http://") || strings.HasPrefix(redirect, "https://"):
|
|
||||||
redirectURL, err := url.Parse(redirect)
|
|
||||||
if err != nil {
|
|
||||||
logger.Printf("Rejecting invalid redirect %q: scheme unsupported or missing", redirect)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
redirectHostname := redirectURL.Hostname()
|
|
||||||
|
|
||||||
for _, domain := range p.whitelistDomains {
|
|
||||||
domainHostname, domainPort := splitHostPort(strings.TrimLeft(domain, "."))
|
|
||||||
if domainHostname == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if (redirectHostname == domainHostname) || (strings.HasPrefix(domain, ".") && strings.HasSuffix(redirectHostname, domainHostname)) {
|
|
||||||
// the domain names match, now validate the ports
|
|
||||||
// if the whitelisted domain's port is '*', allow all ports
|
|
||||||
// if the whitelisted domain contains a specific port, only allow that port
|
|
||||||
// if the whitelisted domain doesn't contain a port at all, only allow empty redirect ports ie http and https
|
|
||||||
redirectPort := redirectURL.Port()
|
|
||||||
if (domainPort == "*") ||
|
|
||||||
(domainPort == redirectPort) ||
|
|
||||||
(domainPort == "" && redirectPort == "") {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Printf("Rejecting invalid redirect %q: domain / port not in whitelist", redirect)
|
|
||||||
return false
|
|
||||||
default:
|
|
||||||
logger.Printf("Rejecting invalid redirect %q: not an absolute or relative URL", redirect)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsAllowedRequest is used to check if auth should be skipped for this request
|
|
||||||
func (p *OAuthProxy) IsAllowedRequest(req *http.Request) bool {
|
|
||||||
isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS"
|
|
||||||
return isPreflightRequestAllowed || p.isAllowedRoute(req) || p.IsTrustedIP(req)
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsAllowedRoute is used to check if the request method & path is allowed without auth
|
|
||||||
func (p *OAuthProxy) isAllowedRoute(req *http.Request) bool {
|
|
||||||
for _, route := range p.allowedRoutes {
|
|
||||||
if (route.method == "" || req.Method == route.method) && route.pathRegex.MatchString(req.URL.Path) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en
|
|
||||||
var noCacheHeaders = map[string]string{
|
|
||||||
"Expires": time.Unix(0, 0).Format(time.RFC1123),
|
|
||||||
"Cache-Control": "no-cache, no-store, must-revalidate, max-age=0",
|
|
||||||
"X-Accel-Expires": "0", // https://www.nginx.com/resources/wiki/start/topics/examples/x-accel/
|
|
||||||
}
|
|
||||||
|
|
||||||
// prepareNoCache prepares headers for preventing browser caching.
|
|
||||||
func prepareNoCache(w http.ResponseWriter) {
|
|
||||||
// Set NoCache headers
|
|
||||||
for k, v := range noCacheHeaders {
|
|
||||||
w.Header().Set(k, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsTrustedIP is used to check if a request comes from a trusted client IP address.
|
|
||||||
func (p *OAuthProxy) IsTrustedIP(req *http.Request) bool {
|
|
||||||
if p.trustedIPs == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
remoteAddr, err := ip.GetClientIP(p.realClientIPParser, req)
|
|
||||||
if err != nil {
|
|
||||||
logger.Errorf("Error obtaining real IP for trusted IP list: %v", err)
|
|
||||||
// Possibly spoofed X-Real-IP header
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if remoteAddr == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return p.trustedIPs.Has(remoteAddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|
||||||
p.preAuthChain.Then(http.HandlerFunc(p.serveHTTP)).ServeHTTP(rw, req)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *OAuthProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
|
|
||||||
if req.URL.Path != p.AuthOnlyPath && strings.HasPrefix(req.URL.Path, p.ProxyPrefix) {
|
|
||||||
prepareNoCache(rw)
|
|
||||||
}
|
|
||||||
|
|
||||||
switch path := req.URL.Path; {
|
|
||||||
case path == p.RobotsPath:
|
|
||||||
p.RobotsTxt(rw)
|
|
||||||
case p.IsAllowedRequest(req):
|
|
||||||
p.SkipAuthProxy(rw, req)
|
|
||||||
case path == p.SignInPath:
|
|
||||||
p.SignIn(rw, req)
|
|
||||||
case path == p.SignOutPath:
|
|
||||||
p.SignOut(rw, req)
|
|
||||||
case path == p.OAuthStartPath:
|
|
||||||
p.OAuthStart(rw, req)
|
|
||||||
case path == p.OAuthCallbackPath:
|
|
||||||
p.OAuthCallback(rw, req)
|
|
||||||
case path == p.AuthOnlyPath:
|
|
||||||
p.AuthenticateOnly(rw, req)
|
|
||||||
case path == p.UserInfoPath:
|
|
||||||
p.UserInfo(rw, req)
|
|
||||||
default:
|
|
||||||
p.Proxy(rw, req)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SignIn serves a page prompting users to sign in
|
// SignIn serves a page prompting users to sign in
|
||||||
func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) {
|
func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) {
|
||||||
redirect, err := p.GetRedirect(req)
|
redirect, err := p.getAppRedirect(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("Error obtaining redirect: %v", err)
|
logger.Errorf("Error obtaining redirect: %v", err)
|
||||||
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
|
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
|
||||||
|
|
@ -812,7 +692,7 @@ func (p *OAuthProxy) UserInfo(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
|
||||||
// SignOut sends a response to clear the authentication cookie
|
// SignOut sends a response to clear the authentication cookie
|
||||||
func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) {
|
func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) {
|
||||||
redirect, err := p.GetRedirect(req)
|
redirect, err := p.getAppRedirect(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("Error obtaining redirect: %v", err)
|
logger.Errorf("Error obtaining redirect: %v", err)
|
||||||
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
|
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
|
||||||
|
|
@ -837,13 +717,13 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
p.SetCSRFCookie(rw, req, nonce)
|
p.SetCSRFCookie(rw, req, nonce)
|
||||||
redirect, err := p.GetRedirect(req)
|
redirect, err := p.getAppRedirect(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("Error obtaining redirect: %v", err)
|
logger.Errorf("Error obtaining redirect: %v", err)
|
||||||
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
|
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
redirectURI := p.GetRedirectURI(util.GetRequestHost(req))
|
redirectURI := p.getOAuthRedirectURI(req)
|
||||||
http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), http.StatusFound)
|
http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -866,7 +746,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
session, err := p.redeemCode(req.Context(), util.GetRequestHost(req), req.Form.Get("code"))
|
session, err := p.redeemCode(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("Error redeeming code during OAuth2 callback: %v", err)
|
logger.Errorf("Error redeeming code during OAuth2 callback: %v", err)
|
||||||
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error")
|
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error")
|
||||||
|
|
@ -925,16 +805,50 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthenticateOnly checks whether the user is currently logged in
|
func (p *OAuthProxy) redeemCode(req *http.Request) (*sessionsapi.SessionState, error) {
|
||||||
func (p *OAuthProxy) AuthenticateOnly(rw http.ResponseWriter, req *http.Request) {
|
code := req.Form.Get("code")
|
||||||
|
if code == "" {
|
||||||
|
return nil, providers.ErrMissingCode
|
||||||
|
}
|
||||||
|
|
||||||
|
redirectURI := p.getOAuthRedirectURI(req)
|
||||||
|
s, err := p.provider.Redeem(req.Context(), redirectURI, code)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *OAuthProxy) enrichSessionState(ctx context.Context, s *sessionsapi.SessionState) error {
|
||||||
|
var err error
|
||||||
|
if s.Email == "" {
|
||||||
|
s.Email, err = p.provider.GetEmailAddress(ctx, s)
|
||||||
|
if err != nil && !errors.Is(err, providers.ErrNotImplemented) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.provider.EnrichSession(ctx, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthOnly checks whether the user is currently logged in (both authentication
|
||||||
|
// and optional authorization).
|
||||||
|
func (p *OAuthProxy) AuthOnly(rw http.ResponseWriter, req *http.Request) {
|
||||||
session, err := p.getAuthenticatedSession(rw, req)
|
session, err := p.getAuthenticatedSession(rw, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(rw, "unauthorized request", http.StatusUnauthorized)
|
http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unauthorized cases need to return 403 to prevent infinite redirects with
|
||||||
|
// subrequest architectures
|
||||||
|
if !authOnlyAuthorize(req, session) {
|
||||||
|
http.Error(rw, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// we are authenticated
|
// we are authenticated
|
||||||
p.addHeadersForProxying(rw, req, session)
|
p.addHeadersForProxying(rw, session)
|
||||||
p.headersChain.Then(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
p.headersChain.Then(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
rw.WriteHeader(http.StatusAccepted)
|
rw.WriteHeader(http.StatusAccepted)
|
||||||
})).ServeHTTP(rw, req)
|
})).ServeHTTP(rw, req)
|
||||||
|
|
@ -952,13 +866,13 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
|
||||||
switch err {
|
switch err {
|
||||||
case nil:
|
case nil:
|
||||||
// we are authenticated
|
// we are authenticated
|
||||||
p.addHeadersForProxying(rw, req, session)
|
p.addHeadersForProxying(rw, session)
|
||||||
p.headersChain.Then(p.serveMux).ServeHTTP(rw, req)
|
p.headersChain.Then(p.serveMux).ServeHTTP(rw, req)
|
||||||
case ErrNeedsLogin:
|
case ErrNeedsLogin:
|
||||||
// we need to send the user to a login screen
|
// we need to send the user to a login screen
|
||||||
if isAjax(req) {
|
if isAjax(req) {
|
||||||
// no point redirecting an AJAX request
|
// no point redirecting an AJAX request
|
||||||
p.ErrorJSON(rw, http.StatusUnauthorized)
|
p.errorJSON(rw, http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -977,7 +891,195 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
|
||||||
p.ErrorPage(rw, http.StatusInternalServerError,
|
p.ErrorPage(rw, http.StatusInternalServerError,
|
||||||
"Internal Error", "Internal Error")
|
"Internal Error", "Internal Error")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en
|
||||||
|
var noCacheHeaders = map[string]string{
|
||||||
|
"Expires": time.Unix(0, 0).Format(time.RFC1123),
|
||||||
|
"Cache-Control": "no-cache, no-store, must-revalidate, max-age=0",
|
||||||
|
"X-Accel-Expires": "0", // https://www.nginx.com/resources/wiki/start/topics/examples/x-accel/
|
||||||
|
}
|
||||||
|
|
||||||
|
// prepareNoCache prepares headers for preventing browser caching.
|
||||||
|
func prepareNoCache(w http.ResponseWriter) {
|
||||||
|
// Set NoCache headers
|
||||||
|
for k, v := range noCacheHeaders {
|
||||||
|
w.Header().Set(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getOAuthRedirectURI returns the redirectURL that the upstream OAuth Provider will
|
||||||
|
// redirect clients to once authenticated.
|
||||||
|
// This is usually the OAuthProxy callback URL.
|
||||||
|
func (p *OAuthProxy) getOAuthRedirectURI(req *http.Request) string {
|
||||||
|
// if `p.redirectURL` already has a host, return it
|
||||||
|
if p.redirectURL.Host != "" {
|
||||||
|
return p.redirectURL.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise figure out the scheme + host from the request
|
||||||
|
rd := *p.redirectURL
|
||||||
|
rd.Host = requestutil.GetRequestHost(req)
|
||||||
|
rd.Scheme = requestutil.GetRequestProto(req)
|
||||||
|
|
||||||
|
// If CookieSecure is true, return `https` no matter what
|
||||||
|
// Not all reverse proxies set X-Forwarded-Proto
|
||||||
|
if p.CookieSecure {
|
||||||
|
rd.Scheme = schemeHTTPS
|
||||||
|
}
|
||||||
|
return rd.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// getAppRedirect determines the full URL or URI path to redirect clients to
|
||||||
|
// once authenticated with the OAuthProxy
|
||||||
|
// Strategy priority (first legal result is used):
|
||||||
|
// - `rd` querysting parameter
|
||||||
|
// - `X-Auth-Request-Redirect` header
|
||||||
|
// - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled)
|
||||||
|
// - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*)
|
||||||
|
// - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled)
|
||||||
|
// - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*)
|
||||||
|
// - `/`
|
||||||
|
func (p *OAuthProxy) getAppRedirect(req *http.Request) (string, error) {
|
||||||
|
err := req.ParseForm()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// These redirect getter functions are strategies ordered by priority
|
||||||
|
// for figuring out the redirect URL.
|
||||||
|
type redirectGetter func(req *http.Request) string
|
||||||
|
for _, rdGetter := range []redirectGetter{
|
||||||
|
p.getRdQuerystringRedirect,
|
||||||
|
p.getXAuthRequestRedirect,
|
||||||
|
p.getXForwardedHeadersRedirect,
|
||||||
|
p.getURIRedirect,
|
||||||
|
} {
|
||||||
|
redirect := rdGetter(req)
|
||||||
|
// Call `p.IsValidRedirect` again here a final time to be safe
|
||||||
|
if redirect != "" && p.IsValidRedirect(redirect) {
|
||||||
|
return redirect, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "/", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isForwardedRequest(req *http.Request) bool {
|
||||||
|
return requestutil.IsProxied(req) &&
|
||||||
|
req.Host != requestutil.GetRequestHost(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *OAuthProxy) hasProxyPrefix(path string) bool {
|
||||||
|
return strings.HasPrefix(path, fmt.Sprintf("%s/", p.ProxyPrefix))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *OAuthProxy) validateRedirect(redirect string, errorFormat string) string {
|
||||||
|
if p.IsValidRedirect(redirect) {
|
||||||
|
return redirect
|
||||||
|
}
|
||||||
|
if redirect != "" {
|
||||||
|
logger.Errorf(errorFormat, redirect)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// getRdQuerystringRedirect handles this getAppRedirect strategy:
|
||||||
|
// - `rd` querysting parameter
|
||||||
|
func (p *OAuthProxy) getRdQuerystringRedirect(req *http.Request) string {
|
||||||
|
return p.validateRedirect(
|
||||||
|
req.Form.Get("rd"),
|
||||||
|
"Invalid redirect provided in rd querystring parameter: %s",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getXAuthRequestRedirect handles this getAppRedirect strategy:
|
||||||
|
// - `X-Auth-Request-Redirect` Header
|
||||||
|
func (p *OAuthProxy) getXAuthRequestRedirect(req *http.Request) string {
|
||||||
|
return p.validateRedirect(
|
||||||
|
req.Header.Get("X-Auth-Request-Redirect"),
|
||||||
|
"Invalid redirect provided in X-Auth-Request-Redirect header: %s",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getXForwardedHeadersRedirect handles these getAppRedirect strategies:
|
||||||
|
// - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled)
|
||||||
|
// - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*)
|
||||||
|
func (p *OAuthProxy) getXForwardedHeadersRedirect(req *http.Request) string {
|
||||||
|
if !isForwardedRequest(req) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
uri := requestutil.GetRequestURI(req)
|
||||||
|
if p.hasProxyPrefix(uri) {
|
||||||
|
uri = "/"
|
||||||
|
}
|
||||||
|
|
||||||
|
redirect := fmt.Sprintf(
|
||||||
|
"%s://%s%s",
|
||||||
|
requestutil.GetRequestProto(req),
|
||||||
|
requestutil.GetRequestHost(req),
|
||||||
|
uri,
|
||||||
|
)
|
||||||
|
|
||||||
|
return p.validateRedirect(redirect,
|
||||||
|
"Invalid redirect generated from X-Forwarded-* headers: %s")
|
||||||
|
}
|
||||||
|
|
||||||
|
// getURIRedirect handles these getAppRedirect strategies:
|
||||||
|
// - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled)
|
||||||
|
// - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*)
|
||||||
|
// - `/`
|
||||||
|
func (p *OAuthProxy) getURIRedirect(req *http.Request) string {
|
||||||
|
redirect := p.validateRedirect(
|
||||||
|
requestutil.GetRequestURI(req),
|
||||||
|
"Invalid redirect generated from X-Forwarded-Uri header: %s",
|
||||||
|
)
|
||||||
|
if redirect == "" {
|
||||||
|
redirect = req.URL.RequestURI()
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.hasProxyPrefix(redirect) {
|
||||||
|
return "/"
|
||||||
|
}
|
||||||
|
return redirect
|
||||||
|
}
|
||||||
|
|
||||||
|
// splitHostPort separates host and port. If the port is not valid, it returns
|
||||||
|
// the entire input as host, and it doesn't check the validity of the host.
|
||||||
|
// Unlike net.SplitHostPort, but per RFC 3986, it requires ports to be numeric.
|
||||||
|
// *** taken from net/url, modified validOptionalPort() to accept ":*"
|
||||||
|
func splitHostPort(hostport string) (host, port string) {
|
||||||
|
host = hostport
|
||||||
|
|
||||||
|
colon := strings.LastIndexByte(host, ':')
|
||||||
|
if colon != -1 && validOptionalPort(host[colon:]) {
|
||||||
|
host, port = host[:colon], host[colon+1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
|
||||||
|
host = host[1 : len(host)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// validOptionalPort reports whether port is either an empty string
|
||||||
|
// or matches /^:\d*$/
|
||||||
|
// *** taken from net/url, modified to accept ":*"
|
||||||
|
func validOptionalPort(port string) bool {
|
||||||
|
if port == "" || port == ":*" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if port[0] != ':' {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, b := range port[1:] {
|
||||||
|
if b < '0' || b > '9' {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// getAuthenticatedSession checks whether a user is authenticated and returns a session object and nil error if so
|
// getAuthenticatedSession checks whether a user is authenticated and returns a session object and nil error if so
|
||||||
|
|
@ -989,7 +1091,7 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R
|
||||||
var session *sessionsapi.SessionState
|
var session *sessionsapi.SessionState
|
||||||
|
|
||||||
getSession := p.sessionChain.Then(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
getSession := p.sessionChain.Then(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
session = middleware.GetRequestScope(req).Session
|
session = middlewareapi.GetRequestScope(req).Session
|
||||||
}))
|
}))
|
||||||
getSession.ServeHTTP(rw, req)
|
getSession.ServeHTTP(rw, req)
|
||||||
|
|
||||||
|
|
@ -1016,8 +1118,55 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R
|
||||||
return session, nil
|
return session, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// authOnlyAuthorize handles special authorization logic that is only done
|
||||||
|
// on the AuthOnly endpoint for use with Nginx subrequest architectures.
|
||||||
|
//
|
||||||
|
// TODO (@NickMeves): This method is a placeholder to be extended but currently
|
||||||
|
// fails the linter. Remove the nolint when functionality expands.
|
||||||
|
//
|
||||||
|
//nolint:S1008
|
||||||
|
func authOnlyAuthorize(req *http.Request, s *sessionsapi.SessionState) bool {
|
||||||
|
// Allow secondary group restrictions based on the `allowed_groups`
|
||||||
|
// querystring parameter
|
||||||
|
if !checkAllowedGroups(req, s) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkAllowedGroups(req *http.Request, s *sessionsapi.SessionState) bool {
|
||||||
|
allowedGroups := extractAllowedGroups(req)
|
||||||
|
if len(allowedGroups) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, group := range s.Groups {
|
||||||
|
if _, ok := allowedGroups[group]; ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractAllowedGroups(req *http.Request) map[string]struct{} {
|
||||||
|
groups := map[string]struct{}{}
|
||||||
|
|
||||||
|
query := req.URL.Query()
|
||||||
|
for _, allowedGroups := range query["allowed_groups"] {
|
||||||
|
for _, group := range strings.Split(allowedGroups, ",") {
|
||||||
|
if group != "" {
|
||||||
|
groups[group] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return groups
|
||||||
|
}
|
||||||
|
|
||||||
// addHeadersForProxying adds the appropriate headers the request / response for proxying
|
// addHeadersForProxying adds the appropriate headers the request / response for proxying
|
||||||
func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) {
|
func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, session *sessionsapi.SessionState) {
|
||||||
if session.Email == "" {
|
if session.Email == "" {
|
||||||
rw.Header().Set("GAP-Auth", session.User)
|
rw.Header().Set("GAP-Auth", session.User)
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -1029,16 +1178,24 @@ func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, req *http.Req
|
||||||
func isAjax(req *http.Request) bool {
|
func isAjax(req *http.Request) bool {
|
||||||
acceptValues := req.Header.Values("Accept")
|
acceptValues := req.Header.Values("Accept")
|
||||||
const ajaxReq = applicationJSON
|
const ajaxReq = applicationJSON
|
||||||
for _, v := range acceptValues {
|
// Iterate over multiple Accept headers, i.e.
|
||||||
if v == ajaxReq {
|
// Accept: application/json
|
||||||
|
// Accept: text/plain
|
||||||
|
for _, mimeTypes := range acceptValues {
|
||||||
|
// Iterate over multiple mimetypes in a single header, i.e.
|
||||||
|
// Accept: application/json, text/plain, */*
|
||||||
|
for _, mimeType := range strings.Split(mimeTypes, ",") {
|
||||||
|
mimeType = strings.TrimSpace(mimeType)
|
||||||
|
if mimeType == ajaxReq {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// ErrorJSON returns the error code with an application/json mime type
|
// errorJSON returns the error code with an application/json mime type
|
||||||
func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) {
|
func (p *OAuthProxy) errorJSON(rw http.ResponseWriter, code int) {
|
||||||
rw.Header().Set("Content-Type", applicationJSON)
|
rw.Header().Set("Content-Type", applicationJSON)
|
||||||
rw.WriteHeader(code)
|
rw.WriteHeader(code)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ import (
|
||||||
|
|
||||||
"github.com/coreos/go-oidc"
|
"github.com/coreos/go-oidc"
|
||||||
"github.com/mbland/hmacauth"
|
"github.com/mbland/hmacauth"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||||
|
|
@ -414,8 +415,9 @@ func Test_redeemCode(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = proxy.redeemCode(context.Background(), "www.example.com", "")
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
assert.Error(t, err)
|
_, err = proxy.redeemCode(req)
|
||||||
|
assert.Equal(t, providers.ErrMissingCode, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_enrichSession(t *testing.T) {
|
func Test_enrichSession(t *testing.T) {
|
||||||
|
|
@ -1197,18 +1199,20 @@ func TestUserInfoEndpointUnauthorizedOnNoCookieSetError(t *testing.T) {
|
||||||
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
|
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAuthOnlyEndpointTest(modifiers ...OptionsModifier) (*ProcessCookieTest, error) {
|
func NewAuthOnlyEndpointTest(querystring string, modifiers ...OptionsModifier) (*ProcessCookieTest, error) {
|
||||||
pcTest, err := NewProcessCookieTestWithOptionsModifiers(modifiers...)
|
pcTest, err := NewProcessCookieTestWithOptionsModifiers(modifiers...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
pcTest.req, _ = http.NewRequest("GET",
|
pcTest.req, _ = http.NewRequest(
|
||||||
pcTest.opts.ProxyPrefix+"/auth", nil)
|
"GET",
|
||||||
|
fmt.Sprintf("%s/auth%s", pcTest.opts.ProxyPrefix, querystring),
|
||||||
|
nil)
|
||||||
return pcTest, nil
|
return pcTest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthOnlyEndpointAccepted(t *testing.T) {
|
func TestAuthOnlyEndpointAccepted(t *testing.T) {
|
||||||
test, err := NewAuthOnlyEndpointTest()
|
test, err := NewAuthOnlyEndpointTest("")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
@ -1226,7 +1230,7 @@ func TestAuthOnlyEndpointAccepted(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) {
|
func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) {
|
||||||
test, err := NewAuthOnlyEndpointTest()
|
test, err := NewAuthOnlyEndpointTest("")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
@ -1234,11 +1238,11 @@ func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) {
|
||||||
test.proxy.ServeHTTP(test.rw, test.req)
|
test.proxy.ServeHTTP(test.rw, test.req)
|
||||||
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
|
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
|
||||||
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
|
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
|
||||||
assert.Equal(t, "unauthorized request\n", string(bodyBytes))
|
assert.Equal(t, "Unauthorized\n", string(bodyBytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
|
func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
|
||||||
test, err := NewAuthOnlyEndpointTest(func(opts *options.Options) {
|
test, err := NewAuthOnlyEndpointTest("", func(opts *options.Options) {
|
||||||
opts.Cookie.Expire = time.Duration(24) * time.Hour
|
opts.Cookie.Expire = time.Duration(24) * time.Hour
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -1254,11 +1258,11 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
|
||||||
test.proxy.ServeHTTP(test.rw, test.req)
|
test.proxy.ServeHTTP(test.rw, test.req)
|
||||||
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
|
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
|
||||||
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
|
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
|
||||||
assert.Equal(t, "unauthorized request\n", string(bodyBytes))
|
assert.Equal(t, "Unauthorized\n", string(bodyBytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
|
func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
|
||||||
test, err := NewAuthOnlyEndpointTest()
|
test, err := NewAuthOnlyEndpointTest("")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
@ -1273,7 +1277,7 @@ func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
|
||||||
test.proxy.ServeHTTP(test.rw, test.req)
|
test.proxy.ServeHTTP(test.rw, test.req)
|
||||||
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
|
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
|
||||||
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
|
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
|
||||||
assert.Equal(t, "unauthorized request\n", string(bodyBytes))
|
assert.Equal(t, "Unauthorized\n", string(bodyBytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
|
func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
|
||||||
|
|
@ -1746,8 +1750,9 @@ func TestRequestSignature(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetRedirect(t *testing.T) {
|
func Test_getAppRedirect(t *testing.T) {
|
||||||
opts := baseTestOptions()
|
opts := baseTestOptions()
|
||||||
|
opts.WhitelistDomains = append(opts.WhitelistDomains, ".example.com", ".example.com:8443")
|
||||||
err := validation.Validate(opts)
|
err := validation.Validate(opts)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
require.NotEmpty(t, opts.ProxyPrefix)
|
require.NotEmpty(t, opts.ProxyPrefix)
|
||||||
|
|
@ -1759,28 +1764,144 @@ func TestGetRedirect(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
url string
|
url string
|
||||||
|
headers map[string]string
|
||||||
|
reverseProxy bool
|
||||||
expectedRedirect string
|
expectedRedirect string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "request outside of ProxyPrefix redirects to original URL",
|
name: "request outside of ProxyPrefix redirects to original URL",
|
||||||
url: "/foo/bar",
|
url: "/foo/bar",
|
||||||
|
headers: nil,
|
||||||
|
reverseProxy: false,
|
||||||
expectedRedirect: "/foo/bar",
|
expectedRedirect: "/foo/bar",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "request with query preserves query",
|
name: "request with query preserves query",
|
||||||
url: "/foo?bar",
|
url: "/foo?bar",
|
||||||
|
headers: nil,
|
||||||
|
reverseProxy: false,
|
||||||
expectedRedirect: "/foo?bar",
|
expectedRedirect: "/foo?bar",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "request under ProxyPrefix redirects to root",
|
name: "request under ProxyPrefix redirects to root",
|
||||||
url: proxy.ProxyPrefix + "/foo/bar",
|
url: proxy.ProxyPrefix + "/foo/bar",
|
||||||
|
headers: nil,
|
||||||
|
reverseProxy: false,
|
||||||
expectedRedirect: "/",
|
expectedRedirect: "/",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "proxied request outside of ProxyPrefix redirects to proxied URL",
|
||||||
|
url: "https://oauth.example.com/foo/bar",
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-Forwarded-Proto": "https",
|
||||||
|
"X-Forwarded-Host": "a-service.example.com",
|
||||||
|
"X-Forwarded-Uri": "/foo/bar",
|
||||||
|
},
|
||||||
|
reverseProxy: true,
|
||||||
|
expectedRedirect: "https://a-service.example.com/foo/bar",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-proxied request with spoofed proxy headers wouldn't redirect",
|
||||||
|
url: "https://oauth.example.com/foo?bar",
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-Forwarded-Proto": "https",
|
||||||
|
"X-Forwarded-Host": "a-service.example.com",
|
||||||
|
"X-Forwarded-Uri": "/foo/bar",
|
||||||
|
},
|
||||||
|
reverseProxy: false,
|
||||||
|
expectedRedirect: "/foo?bar",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "proxied request under ProxyPrefix redirects to root",
|
||||||
|
url: "https://oauth.example.com" + proxy.ProxyPrefix + "/foo/bar",
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-Forwarded-Proto": "https",
|
||||||
|
"X-Forwarded-Host": "a-service.example.com",
|
||||||
|
"X-Forwarded-Uri": proxy.ProxyPrefix + "/foo/bar",
|
||||||
|
},
|
||||||
|
reverseProxy: true,
|
||||||
|
expectedRedirect: "https://a-service.example.com/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "proxied request with port under ProxyPrefix redirects to root",
|
||||||
|
url: "https://oauth.example.com" + proxy.ProxyPrefix + "/foo/bar",
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-Forwarded-Proto": "https",
|
||||||
|
"X-Forwarded-Host": "a-service.example.com:8443",
|
||||||
|
"X-Forwarded-Uri": proxy.ProxyPrefix + "/foo/bar",
|
||||||
|
},
|
||||||
|
reverseProxy: true,
|
||||||
|
expectedRedirect: "https://a-service.example.com:8443/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "proxied request with missing uri header would still redirect to desired redirect",
|
||||||
|
url: "https://oauth.example.com/foo?bar",
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-Forwarded-Proto": "https",
|
||||||
|
"X-Forwarded-Host": "a-service.example.com",
|
||||||
|
},
|
||||||
|
reverseProxy: true,
|
||||||
|
expectedRedirect: "https://a-service.example.com/foo?bar",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "request with headers proxy not being set (and reverse proxy enabled) would still redirect to desired redirect",
|
||||||
|
url: "https://oauth.example.com/foo?bar",
|
||||||
|
headers: nil,
|
||||||
|
reverseProxy: true,
|
||||||
|
expectedRedirect: "/foo?bar",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "proxied request with X-Auth-Request-Redirect being set outside of ProxyPrefix redirects to proxied URL",
|
||||||
|
url: "https://oauth.example.com/foo/bar",
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-Auth-Request-Redirect": "https://a-service.example.com/foo/bar",
|
||||||
|
},
|
||||||
|
reverseProxy: true,
|
||||||
|
expectedRedirect: "https://a-service.example.com/foo/bar",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "proxied request with rd query string redirects to proxied URL",
|
||||||
|
url: "https://oauth.example.com/foo/bar?rd=https%3A%2F%2Fa%2Dservice%2Eexample%2Ecom%2Ffoo%2Fbar",
|
||||||
|
headers: nil,
|
||||||
|
reverseProxy: false,
|
||||||
|
expectedRedirect: "https://a-service.example.com/foo/bar",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "proxied request with rd query string and all headers set (and reverse proxy not enabled) redirects to proxied URL on rd query string",
|
||||||
|
url: "https://oauth.example.com/foo/bar?rd=https%3A%2F%2Fa%2Dservice%2Eexample%2Ecom%2Ffoo%2Fjazz",
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-Auth-Request-Redirect": "https://a-service.example.com/foo/baz",
|
||||||
|
"X-Forwarded-Proto": "http",
|
||||||
|
"X-Forwarded-Host": "another-service.example.com",
|
||||||
|
"X-Forwarded-Uri": "/seasons/greetings",
|
||||||
|
},
|
||||||
|
reverseProxy: false,
|
||||||
|
expectedRedirect: "https://a-service.example.com/foo/jazz",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "proxied request with rd query string and some headers set redirects to proxied URL on rd query string",
|
||||||
|
url: "https://oauth.example.com/foo/bar?rd=https%3A%2F%2Fa%2Dservice%2Eexample%2Ecom%2Ffoo%2Fbaz",
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-Forwarded-Proto": "https",
|
||||||
|
"X-Forwarded-Host": "another-service.example.com",
|
||||||
|
"X-Forwarded-Uri": "/seasons/greetings",
|
||||||
|
},
|
||||||
|
reverseProxy: true,
|
||||||
|
expectedRedirect: "https://a-service.example.com/foo/baz",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
req, _ := http.NewRequest("GET", tt.url, nil)
|
req, _ := http.NewRequest("GET", tt.url, nil)
|
||||||
redirect, err := proxy.GetRedirect(req)
|
for header, value := range tt.headers {
|
||||||
|
if value != "" {
|
||||||
|
req.Header.Add(header, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
req = middleware.AddRequestScope(req, &middleware.RequestScope{
|
||||||
|
ReverseProxy: tt.reverseProxy,
|
||||||
|
})
|
||||||
|
redirect, err := proxy.getAppRedirect(req)
|
||||||
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, tt.expectedRedirect, redirect)
|
assert.Equal(t, tt.expectedRedirect, redirect)
|
||||||
|
|
@ -1848,6 +1969,13 @@ func TestAjaxUnauthorizedRequest2(t *testing.T) {
|
||||||
testAjaxUnauthorizedRequest(t, header)
|
testAjaxUnauthorizedRequest(t, header)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAjaxUnauthorizedRequestAccept1(t *testing.T) {
|
||||||
|
header := make(http.Header)
|
||||||
|
header.Add("Accept", "application/json, text/plain, */*")
|
||||||
|
|
||||||
|
testAjaxUnauthorizedRequest(t, header)
|
||||||
|
}
|
||||||
|
|
||||||
func TestAjaxForbiddendRequest(t *testing.T) {
|
func TestAjaxForbiddendRequest(t *testing.T) {
|
||||||
test, err := newAjaxRequestTest()
|
test, err := newAjaxRequestTest()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -1960,7 +2088,7 @@ func TestGetJwtSession(t *testing.T) {
|
||||||
verifier := oidc.NewVerifier("https://issuer.example.com", keyset,
|
verifier := oidc.NewVerifier("https://issuer.example.com", keyset,
|
||||||
&oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true})
|
&oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true})
|
||||||
|
|
||||||
test, err := NewAuthOnlyEndpointTest(func(opts *options.Options) {
|
test, err := NewAuthOnlyEndpointTest("", func(opts *options.Options) {
|
||||||
opts.InjectRequestHeaders = []options.Header{
|
opts.InjectRequestHeaders = []options.Header{
|
||||||
{
|
{
|
||||||
Name: "Authorization",
|
Name: "Authorization",
|
||||||
|
|
@ -2028,7 +2156,6 @@ func TestGetJwtSession(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
opts.SkipJwtBearerTokens = true
|
opts.SkipJwtBearerTokens = true
|
||||||
opts.SetJWTBearerVerifiers(append(opts.GetJWTBearerVerifiers(), verifier))
|
opts.SetJWTBearerVerifiers(append(opts.GetJWTBearerVerifiers(), verifier))
|
||||||
})
|
})
|
||||||
|
|
@ -2692,32 +2819,106 @@ func TestProxyAllowedGroups(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthOnlyAllowedGroups(t *testing.T) {
|
func TestAuthOnlyAllowedGroups(t *testing.T) {
|
||||||
tests := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
allowedGroups []string
|
allowedGroups []string
|
||||||
groups []string
|
groups []string
|
||||||
expectUnauthorized bool
|
querystring string
|
||||||
|
expectedStatusCode int
|
||||||
}{
|
}{
|
||||||
{"NoAllowedGroups", []string{}, []string{}, false},
|
{
|
||||||
{"NoAllowedGroupsUserHasGroups", []string{}, []string{"a", "b"}, false},
|
name: "NoAllowedGroups",
|
||||||
{"UserInAllowedGroup", []string{"a"}, []string{"a", "b"}, false},
|
allowedGroups: []string{},
|
||||||
{"UserNotInAllowedGroup", []string{"a"}, []string{"c"}, true},
|
groups: []string{},
|
||||||
|
querystring: "",
|
||||||
|
expectedStatusCode: http.StatusAccepted,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NoAllowedGroupsUserHasGroups",
|
||||||
|
allowedGroups: []string{},
|
||||||
|
groups: []string{"a", "b"},
|
||||||
|
querystring: "",
|
||||||
|
expectedStatusCode: http.StatusAccepted,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UserInAllowedGroup",
|
||||||
|
allowedGroups: []string{"a"},
|
||||||
|
groups: []string{"a", "b"},
|
||||||
|
querystring: "",
|
||||||
|
expectedStatusCode: http.StatusAccepted,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UserNotInAllowedGroup",
|
||||||
|
allowedGroups: []string{"a"},
|
||||||
|
groups: []string{"c"},
|
||||||
|
querystring: "",
|
||||||
|
expectedStatusCode: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UserInQuerystringGroup",
|
||||||
|
allowedGroups: []string{"a", "b"},
|
||||||
|
groups: []string{"a", "c"},
|
||||||
|
querystring: "?allowed_groups=a",
|
||||||
|
expectedStatusCode: http.StatusAccepted,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UserInMultiParamQuerystringGroup",
|
||||||
|
allowedGroups: []string{"a", "b"},
|
||||||
|
groups: []string{"b"},
|
||||||
|
querystring: "?allowed_groups=a&allowed_groups=b,d",
|
||||||
|
expectedStatusCode: http.StatusAccepted,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UserInOnlyQuerystringGroup",
|
||||||
|
allowedGroups: []string{},
|
||||||
|
groups: []string{"a", "c"},
|
||||||
|
querystring: "?allowed_groups=a,b",
|
||||||
|
expectedStatusCode: http.StatusAccepted,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UserInDelimitedQuerystringGroup",
|
||||||
|
allowedGroups: []string{"a", "b", "c"},
|
||||||
|
groups: []string{"c"},
|
||||||
|
querystring: "?allowed_groups=a,c",
|
||||||
|
expectedStatusCode: http.StatusAccepted,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UserNotInQuerystringGroup",
|
||||||
|
allowedGroups: []string{},
|
||||||
|
groups: []string{"c"},
|
||||||
|
querystring: "?allowed_groups=a,b",
|
||||||
|
expectedStatusCode: http.StatusForbidden,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UserInConfigGroupNotInQuerystringGroup",
|
||||||
|
allowedGroups: []string{"a", "b", "c"},
|
||||||
|
groups: []string{"c"},
|
||||||
|
querystring: "?allowed_groups=a,b",
|
||||||
|
expectedStatusCode: http.StatusForbidden,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UserInQuerystringGroupNotInConfigGroup",
|
||||||
|
allowedGroups: []string{"a", "b"},
|
||||||
|
groups: []string{"c"},
|
||||||
|
querystring: "?allowed_groups=b,c",
|
||||||
|
expectedStatusCode: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tc := range testCases {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
emailAddress := "test"
|
emailAddress := "test"
|
||||||
created := time.Now()
|
created := time.Now()
|
||||||
|
|
||||||
session := &sessions.SessionState{
|
session := &sessions.SessionState{
|
||||||
Groups: tt.groups,
|
Groups: tc.groups,
|
||||||
Email: emailAddress,
|
Email: emailAddress,
|
||||||
AccessToken: "oauth_token",
|
AccessToken: "oauth_token",
|
||||||
CreatedAt: &created,
|
CreatedAt: &created,
|
||||||
}
|
}
|
||||||
|
|
||||||
test, err := NewAuthOnlyEndpointTest(func(opts *options.Options) {
|
test, err := NewAuthOnlyEndpointTest(tc.querystring, func(opts *options.Options) {
|
||||||
opts.AllowedGroups = tt.allowedGroups
|
opts.AllowedGroups = tc.allowedGroups
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
|
@ -2728,11 +2929,7 @@ func TestAuthOnlyAllowedGroups(t *testing.T) {
|
||||||
|
|
||||||
test.proxy.ServeHTTP(test.rw, test.req)
|
test.proxy.ServeHTTP(test.rw, test.req)
|
||||||
|
|
||||||
if tt.expectUnauthorized {
|
assert.Equal(t, tc.expectedStatusCode, test.rw.Code)
|
||||||
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
|
|
||||||
} else {
|
|
||||||
assert.Equal(t, http.StatusAccepted, test.rw.Code)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,19 @@
|
||||||
|
package middleware_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||||
|
. "github.com/onsi/ginkgo"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestMiddlewareSuite and related tests are in a *_test package
|
||||||
|
// to prevent circular imports with the `logger` package which uses
|
||||||
|
// this functionality
|
||||||
|
func TestMiddlewareSuite(t *testing.T) {
|
||||||
|
logger.SetOutput(GinkgoWriter)
|
||||||
|
|
||||||
|
RegisterFailHandler(Fail)
|
||||||
|
RunSpecs(t, "Middleware API")
|
||||||
|
}
|
||||||
|
|
@ -1,13 +1,26 @@
|
||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type scopeKey string
|
||||||
|
|
||||||
|
// RequestScopeKey uses a typed string to reduce likelihood of clashing
|
||||||
|
// with other context keys
|
||||||
|
const RequestScopeKey scopeKey = "request-scope"
|
||||||
|
|
||||||
// RequestScope contains information regarding the request that is being made.
|
// RequestScope contains information regarding the request that is being made.
|
||||||
// The RequestScope is used to pass information between different middlewares
|
// The RequestScope is used to pass information between different middlewares
|
||||||
// within the chain.
|
// within the chain.
|
||||||
type RequestScope struct {
|
type RequestScope struct {
|
||||||
|
// ReverseProxy tracks whether OAuth2-Proxy is operating in reverse proxy
|
||||||
|
// mode and if request `X-Forwarded-*` headers should be trusted
|
||||||
|
ReverseProxy bool
|
||||||
|
|
||||||
// Session details the authenticated users information (if it exists).
|
// Session details the authenticated users information (if it exists).
|
||||||
Session *sessions.SessionState
|
Session *sessions.SessionState
|
||||||
|
|
||||||
|
|
@ -22,3 +35,19 @@ type RequestScope struct {
|
||||||
// it was loaded or not.
|
// it was loaded or not.
|
||||||
SessionRevalidated bool
|
SessionRevalidated bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetRequestScope returns the current request scope from the given request
|
||||||
|
func GetRequestScope(req *http.Request) *RequestScope {
|
||||||
|
scope := req.Context().Value(RequestScopeKey)
|
||||||
|
if scope == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return scope.(*RequestScope)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRequestScope adds a RequestScope to a request
|
||||||
|
func AddRequestScope(req *http.Request, scope *RequestScope) *http.Request {
|
||||||
|
ctx := context.WithValue(req.Context(), RequestScopeKey, scope)
|
||||||
|
return req.WithContext(ctx)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,56 @@
|
||||||
|
package middleware_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||||
|
. "github.com/onsi/ginkgo"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("Scope Suite", func() {
|
||||||
|
Context("GetRequestScope", func() {
|
||||||
|
var request *http.Request
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
var err error
|
||||||
|
request, err = http.NewRequest("", "http://127.0.0.1/", nil)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("with a scope", func() {
|
||||||
|
var scope *middleware.RequestScope
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
scope = &middleware.RequestScope{}
|
||||||
|
request = middleware.AddRequestScope(request, scope)
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns the scope", func() {
|
||||||
|
s := middleware.GetRequestScope(request)
|
||||||
|
Expect(s).ToNot(BeNil())
|
||||||
|
Expect(s).To(Equal(scope))
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("if the scope is then modified", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
Expect(scope.SaveSession).To(BeFalse())
|
||||||
|
scope.SaveSession = true
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns the updated session", func() {
|
||||||
|
s := middleware.GetRequestScope(request)
|
||||||
|
Expect(s).ToNot(BeNil())
|
||||||
|
Expect(s).To(Equal(scope))
|
||||||
|
Expect(s.SaveSession).To(BeTrue())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("without a scope", func() {
|
||||||
|
It("returns nil", func() {
|
||||||
|
Expect(middleware.GetRequestScope(request)).To(BeNil())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
@ -36,7 +36,7 @@ type Options struct {
|
||||||
TLSKeyFile string `flag:"tls-key-file" cfg:"tls_key_file"`
|
TLSKeyFile string `flag:"tls-key-file" cfg:"tls_key_file"`
|
||||||
|
|
||||||
AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"`
|
AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"`
|
||||||
KeycloakGroup string `flag:"keycloak-group" cfg:"keycloak_group"`
|
KeycloakGroups []string `flag:"keycloak-group" cfg:"keycloak_groups"`
|
||||||
AzureTenant string `flag:"azure-tenant" cfg:"azure_tenant"`
|
AzureTenant string `flag:"azure-tenant" cfg:"azure_tenant"`
|
||||||
BitbucketTeam string `flag:"bitbucket-team" cfg:"bitbucket_team"`
|
BitbucketTeam string `flag:"bitbucket-team" cfg:"bitbucket_team"`
|
||||||
BitbucketRepository string `flag:"bitbucket-repository" cfg:"bitbucket_repository"`
|
BitbucketRepository string `flag:"bitbucket-repository" cfg:"bitbucket_repository"`
|
||||||
|
|
@ -87,6 +87,7 @@ type Options struct {
|
||||||
InsecureOIDCSkipIssuerVerification bool `flag:"insecure-oidc-skip-issuer-verification" cfg:"insecure_oidc_skip_issuer_verification"`
|
InsecureOIDCSkipIssuerVerification bool `flag:"insecure-oidc-skip-issuer-verification" cfg:"insecure_oidc_skip_issuer_verification"`
|
||||||
SkipOIDCDiscovery bool `flag:"skip-oidc-discovery" cfg:"skip_oidc_discovery"`
|
SkipOIDCDiscovery bool `flag:"skip-oidc-discovery" cfg:"skip_oidc_discovery"`
|
||||||
OIDCJwksURL string `flag:"oidc-jwks-url" cfg:"oidc_jwks_url"`
|
OIDCJwksURL string `flag:"oidc-jwks-url" cfg:"oidc_jwks_url"`
|
||||||
|
OIDCEmailClaim string `flag:"oidc-email-claim" cfg:"oidc_email_claim"`
|
||||||
OIDCGroupsClaim string `flag:"oidc-groups-claim" cfg:"oidc_groups_claim"`
|
OIDCGroupsClaim string `flag:"oidc-groups-claim" cfg:"oidc_groups_claim"`
|
||||||
LoginURL string `flag:"login-url" cfg:"login_url"`
|
LoginURL string `flag:"login-url" cfg:"login_url"`
|
||||||
RedeemURL string `flag:"redeem-url" cfg:"redeem_url"`
|
RedeemURL string `flag:"redeem-url" cfg:"redeem_url"`
|
||||||
|
|
@ -148,11 +149,12 @@ func NewOptions() *Options {
|
||||||
SkipAuthPreflight: false,
|
SkipAuthPreflight: false,
|
||||||
Prompt: "", // Change to "login" when ApprovalPrompt officially deprecated
|
Prompt: "", // Change to "login" when ApprovalPrompt officially deprecated
|
||||||
ApprovalPrompt: "force",
|
ApprovalPrompt: "force",
|
||||||
UserIDClaim: "email",
|
|
||||||
InsecureOIDCAllowUnverifiedEmail: false,
|
InsecureOIDCAllowUnverifiedEmail: false,
|
||||||
SkipOIDCDiscovery: false,
|
SkipOIDCDiscovery: false,
|
||||||
Logging: loggingDefaults(),
|
Logging: loggingDefaults(),
|
||||||
OIDCGroupsClaim: "groups",
|
UserIDClaim: providers.OIDCEmailClaim, // Deprecated: Use OIDCEmailClaim
|
||||||
|
OIDCEmailClaim: providers.OIDCEmailClaim,
|
||||||
|
OIDCGroupsClaim: providers.OIDCGroupsClaim,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -179,7 +181,7 @@ func NewFlagSet() *pflag.FlagSet {
|
||||||
|
|
||||||
flagSet.StringSlice("email-domain", []string{}, "authenticate emails with the specified domain (may be given multiple times). Use * to authenticate any email")
|
flagSet.StringSlice("email-domain", []string{}, "authenticate emails with the specified domain (may be given multiple times). Use * to authenticate any email")
|
||||||
flagSet.StringSlice("whitelist-domain", []string{}, "allowed domains for redirection after authentication. Prefix domain with a . to allow subdomains (eg .example.com)")
|
flagSet.StringSlice("whitelist-domain", []string{}, "allowed domains for redirection after authentication. Prefix domain with a . to allow subdomains (eg .example.com)")
|
||||||
flagSet.String("keycloak-group", "", "restrict login to members of this group.")
|
flagSet.StringSlice("keycloak-group", []string{}, "restrict logins to members of these groups (may be given multiple times)")
|
||||||
flagSet.String("azure-tenant", "common", "go to a tenant-specific or common (tenant-independent) endpoint.")
|
flagSet.String("azure-tenant", "common", "go to a tenant-specific or common (tenant-independent) endpoint.")
|
||||||
flagSet.String("bitbucket-team", "", "restrict logins to members of this team")
|
flagSet.String("bitbucket-team", "", "restrict logins to members of this team")
|
||||||
flagSet.String("bitbucket-repository", "", "restrict logins to user with access to this repository")
|
flagSet.String("bitbucket-repository", "", "restrict logins to user with access to this repository")
|
||||||
|
|
@ -226,7 +228,8 @@ func NewFlagSet() *pflag.FlagSet {
|
||||||
flagSet.Bool("insecure-oidc-skip-issuer-verification", false, "Do not verify if issuer matches OIDC discovery URL")
|
flagSet.Bool("insecure-oidc-skip-issuer-verification", false, "Do not verify if issuer matches OIDC discovery URL")
|
||||||
flagSet.Bool("skip-oidc-discovery", false, "Skip OIDC discovery and use manually supplied Endpoints")
|
flagSet.Bool("skip-oidc-discovery", false, "Skip OIDC discovery and use manually supplied Endpoints")
|
||||||
flagSet.String("oidc-jwks-url", "", "OpenID Connect JWKS URL (ie: https://www.googleapis.com/oauth2/v3/certs)")
|
flagSet.String("oidc-jwks-url", "", "OpenID Connect JWKS URL (ie: https://www.googleapis.com/oauth2/v3/certs)")
|
||||||
flagSet.String("oidc-groups-claim", "groups", "which claim contains the user groups")
|
flagSet.String("oidc-groups-claim", providers.OIDCGroupsClaim, "which OIDC claim contains the user groups")
|
||||||
|
flagSet.String("oidc-email-claim", providers.OIDCEmailClaim, "which OIDC claim contains the user's email")
|
||||||
flagSet.String("login-url", "", "Authentication endpoint")
|
flagSet.String("login-url", "", "Authentication endpoint")
|
||||||
flagSet.String("redeem-url", "", "Token redemption endpoint")
|
flagSet.String("redeem-url", "", "Token redemption endpoint")
|
||||||
flagSet.String("profile-url", "", "Profile access endpoint")
|
flagSet.String("profile-url", "", "Profile access endpoint")
|
||||||
|
|
@ -243,7 +246,7 @@ func NewFlagSet() *pflag.FlagSet {
|
||||||
flagSet.String("pubjwk-url", "", "JWK pubkey access endpoint: required by login.gov")
|
flagSet.String("pubjwk-url", "", "JWK pubkey access endpoint: required by login.gov")
|
||||||
flagSet.Bool("gcp-healthchecks", false, "Enable GCP/GKE healthcheck endpoints")
|
flagSet.Bool("gcp-healthchecks", false, "Enable GCP/GKE healthcheck endpoints")
|
||||||
|
|
||||||
flagSet.String("user-id-claim", "email", "which claim contains the user ID")
|
flagSet.String("user-id-claim", providers.OIDCEmailClaim, "(DEPRECATED for `oidc-email-claim`) which claim contains the user ID")
|
||||||
flagSet.StringSlice("allowed-group", []string{}, "restrict logins to members of this group (may be given multiple times)")
|
flagSet.StringSlice("allowed-group", []string{}, "restrict logins to members of this group (may be given multiple times)")
|
||||||
|
|
||||||
flagSet.AddFlagSet(cookieFlagSet())
|
flagSet.AddFlagSet(cookieFlagSet())
|
||||||
|
|
|
||||||
|
|
@ -9,14 +9,14 @@ import (
|
||||||
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util"
|
requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MakeCookie constructs a cookie from the given parameters,
|
// MakeCookie constructs a cookie from the given parameters,
|
||||||
// discovering the domain from the request if not specified.
|
// discovering the domain from the request if not specified.
|
||||||
func MakeCookie(req *http.Request, name string, value string, path string, domain string, httpOnly bool, secure bool, expiration time.Duration, now time.Time, sameSite http.SameSite) *http.Cookie {
|
func MakeCookie(req *http.Request, name string, value string, path string, domain string, httpOnly bool, secure bool, expiration time.Duration, now time.Time, sameSite http.SameSite) *http.Cookie {
|
||||||
if domain != "" {
|
if domain != "" {
|
||||||
host := util.GetRequestHost(req)
|
host := requestutil.GetRequestHost(req)
|
||||||
if h, _, err := net.SplitHostPort(host); err == nil {
|
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||||
host = h
|
host = h
|
||||||
}
|
}
|
||||||
|
|
@ -48,7 +48,7 @@ func MakeCookieFromOptions(req *http.Request, name string, value string, cookieO
|
||||||
// If nothing matches, create the cookie with the shortest domain
|
// If nothing matches, create the cookie with the shortest domain
|
||||||
defaultDomain := ""
|
defaultDomain := ""
|
||||||
if len(cookieOpts.Domains) > 0 {
|
if len(cookieOpts.Domains) > 0 {
|
||||||
logger.Errorf("Warning: request host %q did not match any of the specific cookie domains of %q", util.GetRequestHost(req), strings.Join(cookieOpts.Domains, ","))
|
logger.Errorf("Warning: request host %q did not match any of the specific cookie domains of %q", requestutil.GetRequestHost(req), strings.Join(cookieOpts.Domains, ","))
|
||||||
defaultDomain = cookieOpts.Domains[len(cookieOpts.Domains)-1]
|
defaultDomain = cookieOpts.Domains[len(cookieOpts.Domains)-1]
|
||||||
}
|
}
|
||||||
return MakeCookie(req, name, value, cookieOpts.Path, defaultDomain, cookieOpts.HTTPOnly, cookieOpts.Secure, expiration, now, ParseSameSite(cookieOpts.SameSite))
|
return MakeCookie(req, name, value, cookieOpts.Path, defaultDomain, cookieOpts.HTTPOnly, cookieOpts.Secure, expiration, now, ParseSameSite(cookieOpts.SameSite))
|
||||||
|
|
@ -57,7 +57,7 @@ func MakeCookieFromOptions(req *http.Request, name string, value string, cookieO
|
||||||
// GetCookieDomain returns the correct cookie domain given a list of domains
|
// GetCookieDomain returns the correct cookie domain given a list of domains
|
||||||
// by checking the X-Fowarded-Host and host header of an an http request
|
// by checking the X-Fowarded-Host and host header of an an http request
|
||||||
func GetCookieDomain(req *http.Request, cookieDomains []string) string {
|
func GetCookieDomain(req *http.Request, cookieDomains []string) string {
|
||||||
host := util.GetRequestHost(req)
|
host := requestutil.GetRequestHost(req)
|
||||||
for _, domain := range cookieDomains {
|
for _, domain := range cookieDomains {
|
||||||
if strings.HasSuffix(host, domain) {
|
if strings.HasSuffix(host, domain) {
|
||||||
return domain
|
return domain
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ import (
|
||||||
"text/template"
|
"text/template"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util"
|
requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AuthStatus defines the different types of auth logging that occur
|
// AuthStatus defines the different types of auth logging that occur
|
||||||
|
|
@ -197,7 +197,7 @@ func (l *Logger) PrintAuthf(username string, req *http.Request, status AuthStatu
|
||||||
|
|
||||||
err := l.authTemplate.Execute(l.writer, authLogMessageData{
|
err := l.authTemplate.Execute(l.writer, authLogMessageData{
|
||||||
Client: client,
|
Client: client,
|
||||||
Host: util.GetRequestHost(req),
|
Host: requestutil.GetRequestHost(req),
|
||||||
Protocol: req.Proto,
|
Protocol: req.Proto,
|
||||||
RequestMethod: req.Method,
|
RequestMethod: req.Method,
|
||||||
Timestamp: FormatTimestamp(now),
|
Timestamp: FormatTimestamp(now),
|
||||||
|
|
@ -251,7 +251,7 @@ func (l *Logger) PrintReq(username, upstream string, req *http.Request, url url.
|
||||||
|
|
||||||
err := l.reqTemplate.Execute(l.writer, reqLogMessageData{
|
err := l.reqTemplate.Execute(l.writer, reqLogMessageData{
|
||||||
Client: client,
|
Client: client,
|
||||||
Host: util.GetRequestHost(req),
|
Host: requestutil.GetRequestHost(req),
|
||||||
Protocol: req.Proto,
|
Protocol: req.Proto,
|
||||||
RequestDuration: fmt.Sprintf("%0.3f", duration),
|
RequestDuration: fmt.Sprintf("%0.3f", duration),
|
||||||
RequestMethod: req.Method,
|
RequestMethod: req.Method,
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/justinas/alice"
|
"github.com/justinas/alice"
|
||||||
|
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||||
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||||
|
|
@ -23,7 +24,7 @@ func NewBasicAuthSessionLoader(validator basic.Validator) alice.Constructor {
|
||||||
// If a session was loaded by a previous handler, it will not be replaced.
|
// If a session was loaded by a previous handler, it will not be replaced.
|
||||||
func loadBasicAuthSession(validator basic.Validator, next http.Handler) http.Handler {
|
func loadBasicAuthSession(validator basic.Validator, next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
scope := GetRequestScope(req)
|
scope := middlewareapi.GetRequestScope(req)
|
||||||
// If scope is nil, this will panic.
|
// If scope is nil, this will panic.
|
||||||
// A scope should always be injected before this handler is called.
|
// A scope should always be injected before this handler is called.
|
||||||
if scope.Session != nil {
|
if scope.Session != nil {
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
|
@ -40,8 +39,7 @@ var _ = Describe("Basic Auth Session Suite", func() {
|
||||||
// Set up the request with the authorization header and a request scope
|
// Set up the request with the authorization header and a request scope
|
||||||
req := httptest.NewRequest("", "/", nil)
|
req := httptest.NewRequest("", "/", nil)
|
||||||
req.Header.Set("Authorization", in.authorizationHeader)
|
req.Header.Set("Authorization", in.authorizationHeader)
|
||||||
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope)
|
req = middlewareapi.AddRequestScope(req, scope)
|
||||||
req = req.WithContext(contextWithScope)
|
|
||||||
|
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
|
|
@ -57,7 +55,7 @@ var _ = Describe("Basic Auth Session Suite", func() {
|
||||||
// from the scope
|
// from the scope
|
||||||
var gotSession *sessionsapi.SessionState
|
var gotSession *sessionsapi.SessionState
|
||||||
handler := NewBasicAuthSessionLoader(validator)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := NewBasicAuthSessionLoader(validator)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session
|
gotSession = middlewareapi.GetRequestScope(r).Session
|
||||||
}))
|
}))
|
||||||
handler.ServeHTTP(rw, req)
|
handler.ServeHTTP(rw, req)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/justinas/alice"
|
"github.com/justinas/alice"
|
||||||
|
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/header"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/header"
|
||||||
)
|
)
|
||||||
|
|
@ -61,7 +62,7 @@ func newRequestHeaderInjector(headers []options.Header) (alice.Constructor, erro
|
||||||
|
|
||||||
func injectRequestHeaders(injector header.Injector, next http.Handler) http.Handler {
|
func injectRequestHeaders(injector header.Injector, next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
scope := GetRequestScope(req)
|
scope := middlewareapi.GetRequestScope(req)
|
||||||
|
|
||||||
// If scope is nil, this will panic.
|
// If scope is nil, this will panic.
|
||||||
// A scope should always be injected before this handler is called.
|
// A scope should always be injected before this handler is called.
|
||||||
|
|
@ -92,7 +93,7 @@ func newResponseHeaderInjector(headers []options.Header) (alice.Constructor, err
|
||||||
|
|
||||||
func injectResponseHeaders(injector header.Injector, next http.Handler) http.Handler {
|
func injectResponseHeaders(injector header.Injector, next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
scope := GetRequestScope(req)
|
scope := middlewareapi.GetRequestScope(req)
|
||||||
|
|
||||||
// If scope is nil, this will panic.
|
// If scope is nil, this will panic.
|
||||||
// A scope should always be injected before this handler is called.
|
// A scope should always be injected before this handler is called.
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
|
@ -31,8 +30,7 @@ var _ = Describe("Headers Suite", func() {
|
||||||
|
|
||||||
// Set up the request with a request scope
|
// Set up the request with a request scope
|
||||||
req := httptest.NewRequest("", "/", nil)
|
req := httptest.NewRequest("", "/", nil)
|
||||||
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope)
|
req = middlewareapi.AddRequestScope(req, scope)
|
||||||
req = req.WithContext(contextWithScope)
|
|
||||||
req.Header = in.initialHeaders.Clone()
|
req.Header = in.initialHeaders.Clone()
|
||||||
|
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
|
|
@ -218,8 +216,7 @@ var _ = Describe("Headers Suite", func() {
|
||||||
|
|
||||||
// Set up the request with a request scope
|
// Set up the request with a request scope
|
||||||
req := httptest.NewRequest("", "/", nil)
|
req := httptest.NewRequest("", "/", nil)
|
||||||
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope)
|
req = middlewareapi.AddRequestScope(req, scope)
|
||||||
req = req.WithContext(contextWithScope)
|
|
||||||
|
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
for key, values := range in.initialHeaders {
|
for key, values := range in.initialHeaders {
|
||||||
|
|
|
||||||
|
|
@ -37,7 +37,7 @@ type jwtSessionLoader struct {
|
||||||
// If a session was loaded by a previous handler, it will not be replaced.
|
// If a session was loaded by a previous handler, it will not be replaced.
|
||||||
func (j *jwtSessionLoader) loadSession(next http.Handler) http.Handler {
|
func (j *jwtSessionLoader) loadSession(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
scope := GetRequestScope(req)
|
scope := middlewareapi.GetRequestScope(req)
|
||||||
// If scope is nil, this will panic.
|
// If scope is nil, this will panic.
|
||||||
// A scope should always be injected before this handler is called.
|
// A scope should always be injected before this handler is called.
|
||||||
if scope.Session != nil {
|
if scope.Session != nil {
|
||||||
|
|
|
||||||
|
|
@ -103,8 +103,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
|
||||||
// Set up the request with the authorization header and a request scope
|
// Set up the request with the authorization header and a request scope
|
||||||
req := httptest.NewRequest("", "/", nil)
|
req := httptest.NewRequest("", "/", nil)
|
||||||
req.Header.Set("Authorization", in.authorizationHeader)
|
req.Header.Set("Authorization", in.authorizationHeader)
|
||||||
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope)
|
req = middlewareapi.AddRequestScope(req, scope)
|
||||||
req = req.WithContext(contextWithScope)
|
|
||||||
|
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
|
|
@ -116,7 +115,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
|
||||||
// from the scope
|
// from the scope
|
||||||
var gotSession *sessionsapi.SessionState
|
var gotSession *sessionsapi.SessionState
|
||||||
handler := NewJwtSessionLoader(sessionLoaders)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := NewJwtSessionLoader(sessionLoaders)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session
|
gotSession = middlewareapi.GetRequestScope(r).Session
|
||||||
}))
|
}))
|
||||||
handler.ServeHTTP(rw, req)
|
handler.ServeHTTP(rw, req)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/justinas/alice"
|
"github.com/justinas/alice"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util"
|
requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
const httpsScheme = "https"
|
const httpsScheme = "https"
|
||||||
|
|
@ -26,10 +26,11 @@ func NewRedirectToHTTPS(httpsPort string) alice.Constructor {
|
||||||
// to the port from the httpsAddress given.
|
// to the port from the httpsAddress given.
|
||||||
func redirectToHTTPS(httpsPort string, next http.Handler) http.Handler {
|
func redirectToHTTPS(httpsPort string, next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
proto := req.Header.Get("X-Forwarded-Proto")
|
proto := requestutil.GetRequestProto(req)
|
||||||
if strings.EqualFold(proto, httpsScheme) || (req.TLS != nil && proto == "") {
|
if strings.EqualFold(proto, httpsScheme) || (req.TLS != nil && proto == req.URL.Scheme) {
|
||||||
// Only care about the connection to us being HTTPS if the proto is empty,
|
// Only care about the connection to us being HTTPS if the proto wasn't
|
||||||
// otherwise the proto is source of truth
|
// from a trusted `X-Forwarded-Proto` (proto == req.URL.Scheme).
|
||||||
|
// Otherwise the proto is source of truth
|
||||||
next.ServeHTTP(rw, req)
|
next.ServeHTTP(rw, req)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -41,7 +42,7 @@ func redirectToHTTPS(httpsPort string, next http.Handler) http.Handler {
|
||||||
|
|
||||||
// Set the Host in case the targetURL still does not have one
|
// Set the Host in case the targetURL still does not have one
|
||||||
// or it isn't X-Forwarded-Host aware
|
// or it isn't X-Forwarded-Host aware
|
||||||
targetURL.Host = util.GetRequestHost(req)
|
targetURL.Host = requestutil.GetRequestHost(req)
|
||||||
|
|
||||||
// Overwrite the port if the original request was to a non-standard port
|
// Overwrite the port if the original request was to a non-standard port
|
||||||
if targetURL.Port() != "" {
|
if targetURL.Port() != "" {
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
|
||||||
|
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||||
. "github.com/onsi/ginkgo"
|
. "github.com/onsi/ginkgo"
|
||||||
. "github.com/onsi/ginkgo/extensions/table"
|
. "github.com/onsi/ginkgo/extensions/table"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
|
|
@ -21,6 +22,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
||||||
requestString string
|
requestString string
|
||||||
useTLS bool
|
useTLS bool
|
||||||
headers map[string]string
|
headers map[string]string
|
||||||
|
reverseProxy bool
|
||||||
expectedStatus int
|
expectedStatus int
|
||||||
expectedBody string
|
expectedBody string
|
||||||
expectedLocation string
|
expectedLocation string
|
||||||
|
|
@ -35,6 +37,10 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
||||||
if in.useTLS {
|
if in.useTLS {
|
||||||
req.TLS = &tls.ConnectionState{}
|
req.TLS = &tls.ConnectionState{}
|
||||||
}
|
}
|
||||||
|
scope := &middlewareapi.RequestScope{
|
||||||
|
ReverseProxy: in.reverseProxy,
|
||||||
|
}
|
||||||
|
req = middlewareapi.AddRequestScope(req, scope)
|
||||||
|
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
|
|
@ -52,6 +58,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
||||||
requestString: "http://example.com",
|
requestString: "http://example.com",
|
||||||
useTLS: false,
|
useTLS: false,
|
||||||
headers: map[string]string{},
|
headers: map[string]string{},
|
||||||
|
reverseProxy: false,
|
||||||
expectedStatus: 308,
|
expectedStatus: 308,
|
||||||
expectedBody: permanentRedirectBody("https://example.com"),
|
expectedBody: permanentRedirectBody("https://example.com"),
|
||||||
expectedLocation: "https://example.com",
|
expectedLocation: "https://example.com",
|
||||||
|
|
@ -60,6 +67,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
||||||
requestString: "https://example.com",
|
requestString: "https://example.com",
|
||||||
useTLS: true,
|
useTLS: true,
|
||||||
headers: map[string]string{},
|
headers: map[string]string{},
|
||||||
|
reverseProxy: false,
|
||||||
expectedStatus: 200,
|
expectedStatus: 200,
|
||||||
expectedBody: "test",
|
expectedBody: "test",
|
||||||
}),
|
}),
|
||||||
|
|
@ -69,15 +77,28 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
||||||
headers: map[string]string{
|
headers: map[string]string{
|
||||||
"X-Forwarded-Proto": "HTTPS",
|
"X-Forwarded-Proto": "HTTPS",
|
||||||
},
|
},
|
||||||
|
reverseProxy: true,
|
||||||
expectedStatus: 200,
|
expectedStatus: 200,
|
||||||
expectedBody: "test",
|
expectedBody: "test",
|
||||||
}),
|
}),
|
||||||
|
Entry("without TLS and X-Forwarded-Proto=HTTPS but ReverseProxy not set", &requestTableInput{
|
||||||
|
requestString: "http://example.com",
|
||||||
|
useTLS: false,
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-Forwarded-Proto": "HTTPS",
|
||||||
|
},
|
||||||
|
reverseProxy: false,
|
||||||
|
expectedStatus: 308,
|
||||||
|
expectedBody: permanentRedirectBody("https://example.com"),
|
||||||
|
expectedLocation: "https://example.com",
|
||||||
|
}),
|
||||||
Entry("with TLS and X-Forwarded-Proto=HTTPS", &requestTableInput{
|
Entry("with TLS and X-Forwarded-Proto=HTTPS", &requestTableInput{
|
||||||
requestString: "https://example.com",
|
requestString: "https://example.com",
|
||||||
useTLS: true,
|
useTLS: true,
|
||||||
headers: map[string]string{
|
headers: map[string]string{
|
||||||
"X-Forwarded-Proto": "HTTPS",
|
"X-Forwarded-Proto": "HTTPS",
|
||||||
},
|
},
|
||||||
|
reverseProxy: true,
|
||||||
expectedStatus: 200,
|
expectedStatus: 200,
|
||||||
expectedBody: "test",
|
expectedBody: "test",
|
||||||
}),
|
}),
|
||||||
|
|
@ -87,6 +108,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
||||||
headers: map[string]string{
|
headers: map[string]string{
|
||||||
"X-Forwarded-Proto": "https",
|
"X-Forwarded-Proto": "https",
|
||||||
},
|
},
|
||||||
|
reverseProxy: true,
|
||||||
expectedStatus: 200,
|
expectedStatus: 200,
|
||||||
expectedBody: "test",
|
expectedBody: "test",
|
||||||
}),
|
}),
|
||||||
|
|
@ -96,6 +118,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
||||||
headers: map[string]string{
|
headers: map[string]string{
|
||||||
"X-Forwarded-Proto": "https",
|
"X-Forwarded-Proto": "https",
|
||||||
},
|
},
|
||||||
|
reverseProxy: true,
|
||||||
expectedStatus: 200,
|
expectedStatus: 200,
|
||||||
expectedBody: "test",
|
expectedBody: "test",
|
||||||
}),
|
}),
|
||||||
|
|
@ -105,6 +128,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
||||||
headers: map[string]string{
|
headers: map[string]string{
|
||||||
"X-Forwarded-Proto": "HTTP",
|
"X-Forwarded-Proto": "HTTP",
|
||||||
},
|
},
|
||||||
|
reverseProxy: true,
|
||||||
expectedStatus: 308,
|
expectedStatus: 308,
|
||||||
expectedBody: permanentRedirectBody("https://example.com"),
|
expectedBody: permanentRedirectBody("https://example.com"),
|
||||||
expectedLocation: "https://example.com",
|
expectedLocation: "https://example.com",
|
||||||
|
|
@ -115,6 +139,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
||||||
headers: map[string]string{
|
headers: map[string]string{
|
||||||
"X-Forwarded-Proto": "HTTP",
|
"X-Forwarded-Proto": "HTTP",
|
||||||
},
|
},
|
||||||
|
reverseProxy: true,
|
||||||
expectedStatus: 308,
|
expectedStatus: 308,
|
||||||
expectedBody: permanentRedirectBody("https://example.com"),
|
expectedBody: permanentRedirectBody("https://example.com"),
|
||||||
expectedLocation: "https://example.com",
|
expectedLocation: "https://example.com",
|
||||||
|
|
@ -125,6 +150,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
||||||
headers: map[string]string{
|
headers: map[string]string{
|
||||||
"X-Forwarded-Proto": "http",
|
"X-Forwarded-Proto": "http",
|
||||||
},
|
},
|
||||||
|
reverseProxy: true,
|
||||||
expectedStatus: 308,
|
expectedStatus: 308,
|
||||||
expectedBody: permanentRedirectBody("https://example.com"),
|
expectedBody: permanentRedirectBody("https://example.com"),
|
||||||
expectedLocation: "https://example.com",
|
expectedLocation: "https://example.com",
|
||||||
|
|
@ -135,6 +161,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
||||||
headers: map[string]string{
|
headers: map[string]string{
|
||||||
"X-Forwarded-Proto": "http",
|
"X-Forwarded-Proto": "http",
|
||||||
},
|
},
|
||||||
|
reverseProxy: true,
|
||||||
expectedStatus: 308,
|
expectedStatus: 308,
|
||||||
expectedBody: permanentRedirectBody("https://example.com"),
|
expectedBody: permanentRedirectBody("https://example.com"),
|
||||||
expectedLocation: "https://example.com",
|
expectedLocation: "https://example.com",
|
||||||
|
|
@ -143,6 +170,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
||||||
requestString: "http://example.com:8080",
|
requestString: "http://example.com:8080",
|
||||||
useTLS: false,
|
useTLS: false,
|
||||||
headers: map[string]string{},
|
headers: map[string]string{},
|
||||||
|
reverseProxy: false,
|
||||||
expectedStatus: 308,
|
expectedStatus: 308,
|
||||||
expectedBody: permanentRedirectBody("https://example.com:8443"),
|
expectedBody: permanentRedirectBody("https://example.com:8443"),
|
||||||
expectedLocation: "https://example.com:8443",
|
expectedLocation: "https://example.com:8443",
|
||||||
|
|
@ -151,6 +179,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
||||||
requestString: "https://example.com:8443",
|
requestString: "https://example.com:8443",
|
||||||
useTLS: true,
|
useTLS: true,
|
||||||
headers: map[string]string{},
|
headers: map[string]string{},
|
||||||
|
reverseProxy: false,
|
||||||
expectedStatus: 200,
|
expectedStatus: 200,
|
||||||
expectedBody: "test",
|
expectedBody: "test",
|
||||||
}),
|
}),
|
||||||
|
|
@ -161,6 +190,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
||||||
requestString: "/",
|
requestString: "/",
|
||||||
useTLS: false,
|
useTLS: false,
|
||||||
expectedStatus: 308,
|
expectedStatus: 308,
|
||||||
|
reverseProxy: false,
|
||||||
expectedBody: permanentRedirectBody("https://example.com/"),
|
expectedBody: permanentRedirectBody("https://example.com/"),
|
||||||
expectedLocation: "https://example.com/",
|
expectedLocation: "https://example.com/",
|
||||||
}),
|
}),
|
||||||
|
|
@ -171,6 +201,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
||||||
"X-Forwarded-Proto": "HTTP",
|
"X-Forwarded-Proto": "HTTP",
|
||||||
"X-Forwarded-Host": "external.example.com",
|
"X-Forwarded-Host": "external.example.com",
|
||||||
},
|
},
|
||||||
|
reverseProxy: true,
|
||||||
expectedStatus: 308,
|
expectedStatus: 308,
|
||||||
expectedBody: permanentRedirectBody("https://external.example.com"),
|
expectedBody: permanentRedirectBody("https://external.example.com"),
|
||||||
expectedLocation: "https://external.example.com",
|
expectedLocation: "https://external.example.com",
|
||||||
|
|
|
||||||
|
|
@ -1,39 +1,20 @@
|
||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/justinas/alice"
|
"github.com/justinas/alice"
|
||||||
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||||
)
|
)
|
||||||
|
|
||||||
type scopeKey string
|
func NewScope(reverseProxy bool) alice.Constructor {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
// requestScopeKey uses a typed string to reduce likelihood of clasing
|
|
||||||
// with other context keys
|
|
||||||
const requestScopeKey scopeKey = "request-scope"
|
|
||||||
|
|
||||||
func NewScope() alice.Constructor {
|
|
||||||
return addScope
|
|
||||||
}
|
|
||||||
|
|
||||||
// addScope injects a new request scope into the request context.
|
|
||||||
func addScope(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
scope := &middlewareapi.RequestScope{}
|
scope := &middlewareapi.RequestScope{
|
||||||
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope)
|
ReverseProxy: reverseProxy,
|
||||||
requestWithScope := req.WithContext(contextWithScope)
|
}
|
||||||
next.ServeHTTP(rw, requestWithScope)
|
req = middlewareapi.AddRequestScope(req, scope)
|
||||||
|
next.ServeHTTP(rw, req)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRequestScope returns the current request scope from the given request
|
|
||||||
func GetRequestScope(req *http.Request) *middlewareapi.RequestScope {
|
|
||||||
scope := req.Context().Value(requestScopeKey)
|
|
||||||
if scope == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return scope.(*middlewareapi.RequestScope)
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
|
||||||
|
|
@ -21,8 +20,11 @@ var _ = Describe("Scope Suite", func() {
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
rw = httptest.NewRecorder()
|
rw = httptest.NewRecorder()
|
||||||
|
})
|
||||||
|
|
||||||
handler := NewScope()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
Context("ReverseProxy is false", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
handler := NewScope(false)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
nextRequest = r
|
nextRequest = r
|
||||||
w.WriteHeader(200)
|
w.WriteHeader(200)
|
||||||
}))
|
}))
|
||||||
|
|
@ -30,64 +32,37 @@ var _ = Describe("Scope Suite", func() {
|
||||||
})
|
})
|
||||||
|
|
||||||
It("does not add a scope to the original request", func() {
|
It("does not add a scope to the original request", func() {
|
||||||
Expect(request.Context().Value(requestScopeKey)).To(BeNil())
|
Expect(request.Context().Value(middlewareapi.RequestScopeKey)).To(BeNil())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("cannot load a scope from the original request using GetRequestScope", func() {
|
It("cannot load a scope from the original request using GetRequestScope", func() {
|
||||||
Expect(GetRequestScope(request)).To(BeNil())
|
Expect(middlewareapi.GetRequestScope(request)).To(BeNil())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("adds a scope to the request for the next handler", func() {
|
It("adds a scope to the request for the next handler", func() {
|
||||||
Expect(nextRequest.Context().Value(requestScopeKey)).ToNot(BeNil())
|
Expect(nextRequest.Context().Value(middlewareapi.RequestScopeKey)).ToNot(BeNil())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("can load a scope from the next handler's request using GetRequestScope", func() {
|
It("can load a scope from the next handler's request using GetRequestScope", func() {
|
||||||
Expect(GetRequestScope(nextRequest)).ToNot(BeNil())
|
scope := middlewareapi.GetRequestScope(nextRequest)
|
||||||
|
Expect(scope).ToNot(BeNil())
|
||||||
|
Expect(scope.ReverseProxy).To(BeFalse())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("GetRequestScope", func() {
|
Context("ReverseProxy is true", func() {
|
||||||
var request *http.Request
|
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
var err error
|
handler := NewScope(true)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
request, err = http.NewRequest("", "http://127.0.0.1/", nil)
|
nextRequest = r
|
||||||
Expect(err).ToNot(HaveOccurred())
|
w.WriteHeader(200)
|
||||||
|
}))
|
||||||
|
handler.ServeHTTP(rw, request)
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("with a scope", func() {
|
It("return a scope where the ReverseProxy field is true", func() {
|
||||||
var scope *middlewareapi.RequestScope
|
scope := middlewareapi.GetRequestScope(nextRequest)
|
||||||
|
Expect(scope).ToNot(BeNil())
|
||||||
BeforeEach(func() {
|
Expect(scope.ReverseProxy).To(BeTrue())
|
||||||
scope = &middlewareapi.RequestScope{}
|
|
||||||
contextWithScope := context.WithValue(request.Context(), requestScopeKey, scope)
|
|
||||||
request = request.WithContext(contextWithScope)
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns the scope", func() {
|
|
||||||
s := GetRequestScope(request)
|
|
||||||
Expect(s).ToNot(BeNil())
|
|
||||||
Expect(s).To(Equal(scope))
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("if the scope is then modified", func() {
|
|
||||||
BeforeEach(func() {
|
|
||||||
Expect(scope.SaveSession).To(BeFalse())
|
|
||||||
scope.SaveSession = true
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns the updated session", func() {
|
|
||||||
s := GetRequestScope(request)
|
|
||||||
Expect(s).ToNot(BeNil())
|
|
||||||
Expect(s).To(Equal(scope))
|
|
||||||
Expect(s.SaveSession).To(BeTrue())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
Context("without a scope", func() {
|
|
||||||
It("returns nil", func() {
|
|
||||||
Expect(GetRequestScope(request)).To(BeNil())
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/justinas/alice"
|
"github.com/justinas/alice"
|
||||||
|
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||||
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||||
)
|
)
|
||||||
|
|
@ -59,7 +60,7 @@ type storedSessionLoader struct {
|
||||||
// If a session was loader by a previous handler, it will not be replaced.
|
// If a session was loader by a previous handler, it will not be replaced.
|
||||||
func (s *storedSessionLoader) loadSession(next http.Handler) http.Handler {
|
func (s *storedSessionLoader) loadSession(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
scope := GetRequestScope(req)
|
scope := middlewareapi.GetRequestScope(req)
|
||||||
// If scope is nil, this will panic.
|
// If scope is nil, this will panic.
|
||||||
// A scope should always be injected before this handler is called.
|
// A scope should always be injected before this handler is called.
|
||||||
if scope.Session != nil {
|
if scope.Session != nil {
|
||||||
|
|
|
||||||
|
|
@ -104,8 +104,7 @@ var _ = Describe("Stored Session Suite", func() {
|
||||||
// Set up the request with the request headesr and a request scope
|
// Set up the request with the request headesr and a request scope
|
||||||
req := httptest.NewRequest("", "/", nil)
|
req := httptest.NewRequest("", "/", nil)
|
||||||
req.Header = in.requestHeaders
|
req.Header = in.requestHeaders
|
||||||
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope)
|
req = middlewareapi.AddRequestScope(req, scope)
|
||||||
req = req.WithContext(contextWithScope)
|
|
||||||
|
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
|
|
||||||
|
|
@ -120,7 +119,7 @@ var _ = Describe("Stored Session Suite", func() {
|
||||||
// from the scope
|
// from the scope
|
||||||
var gotSession *sessionsapi.SessionState
|
var gotSession *sessionsapi.SessionState
|
||||||
handler := NewStoredSessionLoader(opts)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := NewStoredSessionLoader(opts)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session
|
gotSession = middlewareapi.GetRequestScope(r).Session
|
||||||
}))
|
}))
|
||||||
handler.ServeHTTP(rw, req)
|
handler.ServeHTTP(rw, req)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,48 @@
|
||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetRequestProto returns the request scheme or X-Forwarded-Proto if present
|
||||||
|
// and the request is proxied.
|
||||||
|
func GetRequestProto(req *http.Request) string {
|
||||||
|
proto := req.Header.Get("X-Forwarded-Proto")
|
||||||
|
if !IsProxied(req) || proto == "" {
|
||||||
|
proto = req.URL.Scheme
|
||||||
|
}
|
||||||
|
return proto
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRequestHost returns the request host header or X-Forwarded-Host if
|
||||||
|
// present and the request is proxied.
|
||||||
|
func GetRequestHost(req *http.Request) string {
|
||||||
|
host := req.Header.Get("X-Forwarded-Host")
|
||||||
|
if !IsProxied(req) || host == "" {
|
||||||
|
host = req.Host
|
||||||
|
}
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRequestURI return the request URI or X-Forwarded-Uri if present and the
|
||||||
|
// request is proxied.
|
||||||
|
func GetRequestURI(req *http.Request) string {
|
||||||
|
uri := req.Header.Get("X-Forwarded-Uri")
|
||||||
|
if !IsProxied(req) || uri == "" {
|
||||||
|
// Use RequestURI to preserve ?query
|
||||||
|
uri = req.URL.RequestURI()
|
||||||
|
}
|
||||||
|
return uri
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsProxied determines if a request was from a proxy based on the RequestScope
|
||||||
|
// ReverseProxy tracker.
|
||||||
|
func IsProxied(req *http.Request) bool {
|
||||||
|
scope := middlewareapi.GetRequestScope(req)
|
||||||
|
if scope == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return scope.ReverseProxy
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,19 @@
|
||||||
|
package util_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||||
|
. "github.com/onsi/ginkgo"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestRequestUtilSuite and related tests are in a *_test package
|
||||||
|
// to prevent circular imports with the `logger` package which uses
|
||||||
|
// this functionality
|
||||||
|
func TestRequestUtilSuite(t *testing.T) {
|
||||||
|
logger.SetOutput(GinkgoWriter)
|
||||||
|
|
||||||
|
RegisterFailHandler(Fail)
|
||||||
|
RunSpecs(t, "Request Utils")
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,131 @@
|
||||||
|
package util_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util"
|
||||||
|
. "github.com/onsi/ginkgo"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("Util Suite", func() {
|
||||||
|
const (
|
||||||
|
proto = "http"
|
||||||
|
host = "www.oauth2proxy.test"
|
||||||
|
uri = "/test/endpoint"
|
||||||
|
)
|
||||||
|
var req *http.Request
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
req = httptest.NewRequest(
|
||||||
|
http.MethodGet,
|
||||||
|
fmt.Sprintf("%s://%s%s", proto, host, uri),
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("GetRequestHost", func() {
|
||||||
|
Context("IsProxied is false", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
req = middleware.AddRequestScope(req, &middleware.RequestScope{})
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns the host", func() {
|
||||||
|
Expect(util.GetRequestHost(req)).To(Equal(host))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("ignores X-Forwarded-Host and returns the host", func() {
|
||||||
|
req.Header.Add("X-Forwarded-Host", "external.oauth2proxy.text")
|
||||||
|
Expect(util.GetRequestHost(req)).To(Equal(host))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("IsProxied is true", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
req = middleware.AddRequestScope(req, &middleware.RequestScope{
|
||||||
|
ReverseProxy: true,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns the host if X-Forwarded-Host is not present", func() {
|
||||||
|
Expect(util.GetRequestHost(req)).To(Equal(host))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns the X-Forwarded-Host when present", func() {
|
||||||
|
req.Header.Add("X-Forwarded-Host", "external.oauth2proxy.text")
|
||||||
|
Expect(util.GetRequestHost(req)).To(Equal("external.oauth2proxy.text"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("GetRequestProto", func() {
|
||||||
|
Context("IsProxied is false", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
req = middleware.AddRequestScope(req, &middleware.RequestScope{})
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns the scheme", func() {
|
||||||
|
Expect(util.GetRequestProto(req)).To(Equal(proto))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("ignores X-Forwarded-Proto and returns the scheme", func() {
|
||||||
|
req.Header.Add("X-Forwarded-Proto", "https")
|
||||||
|
Expect(util.GetRequestProto(req)).To(Equal(proto))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("IsProxied is true", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
req = middleware.AddRequestScope(req, &middleware.RequestScope{
|
||||||
|
ReverseProxy: true,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns the scheme if X-Forwarded-Proto is not present", func() {
|
||||||
|
Expect(util.GetRequestProto(req)).To(Equal(proto))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns the X-Forwarded-Proto when present", func() {
|
||||||
|
req.Header.Add("X-Forwarded-Proto", "https")
|
||||||
|
Expect(util.GetRequestProto(req)).To(Equal("https"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("GetRequestURI", func() {
|
||||||
|
Context("IsProxied is false", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
req = middleware.AddRequestScope(req, &middleware.RequestScope{})
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns the URI", func() {
|
||||||
|
Expect(util.GetRequestURI(req)).To(Equal(uri))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("ignores X-Forwarded-Uri and returns the URI", func() {
|
||||||
|
req.Header.Add("X-Forwarded-Uri", "/some/other/path")
|
||||||
|
Expect(util.GetRequestURI(req)).To(Equal(uri))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("IsProxied is true", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
req = middleware.AddRequestScope(req, &middleware.RequestScope{
|
||||||
|
ReverseProxy: true,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns the URI if X-Forwarded-Uri is not present", func() {
|
||||||
|
Expect(util.GetRequestURI(req)).To(Equal(uri))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns the X-Forwarded-Uri when present", func() {
|
||||||
|
req.Header.Add("X-Forwarded-Uri", "/some/other/path")
|
||||||
|
Expect(util.GetRequestURI(req)).To(Equal("/some/other/path"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||||
|
|
@ -220,12 +219,12 @@ func loadCookie(req *http.Request, cookieName string) (*http.Cookie, error) {
|
||||||
if len(cookies) == 0 {
|
if len(cookies) == 0 {
|
||||||
return nil, fmt.Errorf("could not find cookie %s", cookieName)
|
return nil, fmt.Errorf("could not find cookie %s", cookieName)
|
||||||
}
|
}
|
||||||
return joinCookies(cookies)
|
return joinCookies(cookies, cookieName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// joinCookies takes a slice of cookies from the request and reconstructs the
|
// joinCookies takes a slice of cookies from the request and reconstructs the
|
||||||
// full session cookie
|
// full session cookie
|
||||||
func joinCookies(cookies []*http.Cookie) (*http.Cookie, error) {
|
func joinCookies(cookies []*http.Cookie, cookieName string) (*http.Cookie, error) {
|
||||||
if len(cookies) == 0 {
|
if len(cookies) == 0 {
|
||||||
return nil, fmt.Errorf("list of cookies must be > 0")
|
return nil, fmt.Errorf("list of cookies must be > 0")
|
||||||
}
|
}
|
||||||
|
|
@ -236,7 +235,7 @@ func joinCookies(cookies []*http.Cookie) (*http.Cookie, error) {
|
||||||
for i := 1; i < len(cookies); i++ {
|
for i := 1; i < len(cookies); i++ {
|
||||||
c.Value += cookies[i].Value
|
c.Value += cookies[i].Value
|
||||||
}
|
}
|
||||||
c.Name = strings.TrimRight(c.Name, "_0")
|
c.Name = cookieName
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -154,9 +154,58 @@ func Test_splitCookie_joinCookies(t *testing.T) {
|
||||||
Value: value,
|
Value: value,
|
||||||
}
|
}
|
||||||
splitCookies := splitCookie(cookie)
|
splitCookies := splitCookie(cookie)
|
||||||
joinedCookie, err := joinCookies(splitCookies)
|
joinedCookie, err := joinCookies(splitCookies, cookie.Name)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, *cookie, *joinedCookie)
|
assert.Equal(t, *cookie, *joinedCookie)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_joinCookies_withUnderlineSuffix(t *testing.T) {
|
||||||
|
testCases := map[string]struct {
|
||||||
|
CookieName string
|
||||||
|
SplitOrder []int
|
||||||
|
}{
|
||||||
|
"Ascending order split with \"_\" suffix": {
|
||||||
|
CookieName: "_cookie_name_",
|
||||||
|
SplitOrder: []int{0, 1, 2, 3, 4},
|
||||||
|
},
|
||||||
|
"Descending order split with \"_\" suffix": {
|
||||||
|
CookieName: "_cookie_name_",
|
||||||
|
SplitOrder: []int{4, 3, 2, 1, 0},
|
||||||
|
},
|
||||||
|
"Arbitrary order split with \"_\" suffix": {
|
||||||
|
CookieName: "_cookie_name_",
|
||||||
|
SplitOrder: []int{3, 1, 2, 0, 4},
|
||||||
|
},
|
||||||
|
"Arbitrary order split with \"_0\" suffix": {
|
||||||
|
CookieName: "_cookie_name_0",
|
||||||
|
SplitOrder: []int{1, 3, 0, 2, 4},
|
||||||
|
},
|
||||||
|
"Arbitrary order split with \"_1\" suffix": {
|
||||||
|
CookieName: "_cookie_name_1",
|
||||||
|
SplitOrder: []int{4, 1, 3, 0, 2},
|
||||||
|
},
|
||||||
|
"Arbitrary order split with \"__\" suffix": {
|
||||||
|
CookieName: "_cookie_name__",
|
||||||
|
SplitOrder: []int{1, 0, 4, 3, 2},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for testName, testCase := range testCases {
|
||||||
|
t.Run(testName, func(t *testing.T) {
|
||||||
|
cookieName := testCase.CookieName
|
||||||
|
var splitCookies []*http.Cookie
|
||||||
|
for _, splitSuffix := range testCase.SplitOrder {
|
||||||
|
cookie := &http.Cookie{
|
||||||
|
Name: splitCookieName(cookieName, splitSuffix),
|
||||||
|
Value: strings.Repeat("v", 1000),
|
||||||
|
}
|
||||||
|
splitCookies = append(splitCookies, cookie)
|
||||||
|
}
|
||||||
|
joinedCookie, err := joinCookies(splitCookies, cookieName)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, cookieName, joinedCookie.Name)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetCertPool(paths []string) (*x509.CertPool, error) {
|
func GetCertPool(paths []string) (*x509.CertPool, error) {
|
||||||
|
|
@ -24,12 +23,3 @@ func GetCertPool(paths []string) (*x509.CertPool, error) {
|
||||||
}
|
}
|
||||||
return pool, nil
|
return pool, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRequestHost return the request host header or X-Forwarded-Host if present
|
|
||||||
func GetRequestHost(req *http.Request) string {
|
|
||||||
host := req.Header.Get("X-Forwarded-Host")
|
|
||||||
if host == "" {
|
|
||||||
host = req.Host
|
|
||||||
}
|
|
||||||
return host
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -4,11 +4,9 @@ import (
|
||||||
"crypto/x509/pkix"
|
"crypto/x509/pkix"
|
||||||
"encoding/asn1"
|
"encoding/asn1"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http/httptest"
|
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -97,16 +95,3 @@ func TestGetCertPool(t *testing.T) {
|
||||||
expectedSubjects := []string{testCA1Subj, testCA2Subj}
|
expectedSubjects := []string{testCA1Subj, testCA2Subj}
|
||||||
assert.Equal(t, expectedSubjects, got)
|
assert.Equal(t, expectedSubjects, got)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetRequestHost(t *testing.T) {
|
|
||||||
g := NewWithT(t)
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "https://example.com", nil)
|
|
||||||
host := GetRequestHost(req)
|
|
||||||
g.Expect(host).To(Equal("example.com"))
|
|
||||||
|
|
||||||
proxyReq := httptest.NewRequest("GET", "http://internal.example.com", nil)
|
|
||||||
proxyReq.Header.Add("X-Forwarded-Host", "external.example.com")
|
|
||||||
extHost := GetRequestHost(proxyReq)
|
|
||||||
g.Expect(extHost).To(Equal("external.example.com"))
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -233,9 +233,19 @@ func parseProviderInfo(o *options.Options, msgs []string) []string {
|
||||||
p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs)
|
p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs)
|
||||||
p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs)
|
p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs)
|
||||||
|
|
||||||
// Make the OIDC Verifier accessible to all providers that can support it
|
// Make the OIDC options available to all providers that support it
|
||||||
|
p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail
|
||||||
|
p.EmailClaim = o.OIDCEmailClaim
|
||||||
|
p.GroupsClaim = o.OIDCGroupsClaim
|
||||||
p.Verifier = o.GetOIDCVerifier()
|
p.Verifier = o.GetOIDCVerifier()
|
||||||
|
|
||||||
|
// TODO (@NickMeves) - Remove This
|
||||||
|
// Backwards Compatibility for Deprecated UserIDClaim option
|
||||||
|
if o.OIDCEmailClaim == providers.OIDCEmailClaim &&
|
||||||
|
o.UserIDClaim != providers.OIDCEmailClaim {
|
||||||
|
p.EmailClaim = o.UserIDClaim
|
||||||
|
}
|
||||||
|
|
||||||
p.SetAllowedGroups(o.AllowedGroups)
|
p.SetAllowedGroups(o.AllowedGroups)
|
||||||
|
|
||||||
provider := providers.New(o.ProviderType, p)
|
provider := providers.New(o.ProviderType, p)
|
||||||
|
|
@ -253,7 +263,10 @@ func parseProviderInfo(o *options.Options, msgs []string) []string {
|
||||||
p.SetRepo(o.GitHubRepo, o.GitHubToken)
|
p.SetRepo(o.GitHubRepo, o.GitHubToken)
|
||||||
p.SetUsers(o.GitHubUsers)
|
p.SetUsers(o.GitHubUsers)
|
||||||
case *providers.KeycloakProvider:
|
case *providers.KeycloakProvider:
|
||||||
p.SetGroup(o.KeycloakGroup)
|
// Backwards compatibility with `--keycloak-group` option
|
||||||
|
if len(o.KeycloakGroups) > 0 {
|
||||||
|
p.SetAllowedGroups(o.KeycloakGroups)
|
||||||
|
}
|
||||||
case *providers.GoogleProvider:
|
case *providers.GoogleProvider:
|
||||||
if o.GoogleServiceAccountJSON != "" {
|
if o.GoogleServiceAccountJSON != "" {
|
||||||
file, err := os.Open(o.GoogleServiceAccountJSON)
|
file, err := os.Open(o.GoogleServiceAccountJSON)
|
||||||
|
|
@ -273,14 +286,10 @@ func parseProviderInfo(o *options.Options, msgs []string) []string {
|
||||||
p.SetTeam(o.BitbucketTeam)
|
p.SetTeam(o.BitbucketTeam)
|
||||||
p.SetRepository(o.BitbucketRepository)
|
p.SetRepository(o.BitbucketRepository)
|
||||||
case *providers.OIDCProvider:
|
case *providers.OIDCProvider:
|
||||||
p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail
|
|
||||||
p.UserIDClaim = o.UserIDClaim
|
|
||||||
p.GroupsClaim = o.OIDCGroupsClaim
|
|
||||||
if p.Verifier == nil {
|
if p.Verifier == nil {
|
||||||
msgs = append(msgs, "oidc provider requires an oidc issuer URL")
|
msgs = append(msgs, "oidc provider requires an oidc issuer URL")
|
||||||
}
|
}
|
||||||
case *providers.GitLabProvider:
|
case *providers.GitLabProvider:
|
||||||
p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail
|
|
||||||
p.Groups = o.GitLabGroup
|
p.Groups = o.GitLabGroup
|
||||||
err := p.AddProjects(o.GitlabProjects)
|
err := p.AddProjects(o.GitlabProjects)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -20,8 +20,6 @@ type GitLabProvider struct {
|
||||||
|
|
||||||
Groups []string
|
Groups []string
|
||||||
Projects []*GitlabProject
|
Projects []*GitlabProject
|
||||||
|
|
||||||
AllowUnverifiedEmail bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GitlabProject represents a Gitlab project constraint entity
|
// GitlabProject represents a Gitlab project constraint entity
|
||||||
|
|
@ -103,7 +101,7 @@ func (p *GitLabProvider) Redeem(ctx context.Context, redirectURL, code string) (
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("token exchange: %v", err)
|
return nil, fmt.Errorf("token exchange: %v", err)
|
||||||
}
|
}
|
||||||
s, err = p.createSessionState(ctx, token)
|
s, err = p.createSession(ctx, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to update session: %v", err)
|
return nil, fmt.Errorf("unable to update session: %v", err)
|
||||||
}
|
}
|
||||||
|
|
@ -162,7 +160,7 @@ func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.Ses
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get token: %v", err)
|
return fmt.Errorf("failed to get token: %v", err)
|
||||||
}
|
}
|
||||||
newSession, err := p.createSessionState(ctx, token)
|
newSession, err := p.createSession(ctx, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to update session: %v", err)
|
return fmt.Errorf("unable to update session: %v", err)
|
||||||
}
|
}
|
||||||
|
|
@ -255,22 +253,21 @@ func (p *GitLabProvider) AddProjects(projects []string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *GitLabProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*sessions.SessionState, error) {
|
func (p *GitLabProvider) createSession(ctx context.Context, token *oauth2.Token) (*sessions.SessionState, error) {
|
||||||
rawIDToken, ok := token.Extra("id_token").(string)
|
idToken, err := p.verifyIDToken(ctx, token)
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("token response did not contain an id_token")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse and verify ID Token payload.
|
|
||||||
idToken, err := p.Verifier.Verify(ctx, rawIDToken)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
switch err {
|
||||||
|
case ErrMissingIDToken:
|
||||||
|
return nil, fmt.Errorf("token response did not contain an id_token")
|
||||||
|
default:
|
||||||
return nil, fmt.Errorf("could not verify id_token: %v", err)
|
return nil, fmt.Errorf("could not verify id_token: %v", err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
created := time.Now()
|
created := time.Now()
|
||||||
return &sessions.SessionState{
|
return &sessions.SessionState{
|
||||||
AccessToken: token.AccessToken,
|
AccessToken: token.AccessToken,
|
||||||
IDToken: rawIDToken,
|
IDToken: getIDToken(token),
|
||||||
RefreshToken: token.RefreshToken,
|
RefreshToken: token.RefreshToken,
|
||||||
CreatedAt: &created,
|
CreatedAt: &created,
|
||||||
ExpiresOn: &idToken.Expiry,
|
ExpiresOn: &idToken.Expiry,
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ package providers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||||
|
|
@ -11,7 +12,6 @@ import (
|
||||||
|
|
||||||
type KeycloakProvider struct {
|
type KeycloakProvider struct {
|
||||||
*ProviderData
|
*ProviderData
|
||||||
Group string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Provider = (*KeycloakProvider)(nil)
|
var _ Provider = (*KeycloakProvider)(nil)
|
||||||
|
|
@ -47,6 +47,7 @@ var (
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// NewKeycloakProvider creates a KeyCloakProvider using the passed ProviderData
|
||||||
func NewKeycloakProvider(p *ProviderData) *KeycloakProvider {
|
func NewKeycloakProvider(p *ProviderData) *KeycloakProvider {
|
||||||
p.setProviderDefaults(providerDefaults{
|
p.setProviderDefaults(providerDefaults{
|
||||||
name: keycloakProviderName,
|
name: keycloakProviderName,
|
||||||
|
|
@ -59,41 +60,39 @@ func NewKeycloakProvider(p *ProviderData) *KeycloakProvider {
|
||||||
return &KeycloakProvider{ProviderData: p}
|
return &KeycloakProvider{ProviderData: p}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *KeycloakProvider) SetGroup(group string) {
|
// EnrichSession uses the Keycloak userinfo endpoint to populate the session's
|
||||||
p.Group = group
|
// email and groups.
|
||||||
|
func (p *KeycloakProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error {
|
||||||
|
// Fallback to ValidateURL if ProfileURL not set for legacy compatibility
|
||||||
|
profileURL := p.ValidateURL.String()
|
||||||
|
if p.ProfileURL.String() != "" {
|
||||||
|
profileURL = p.ProfileURL.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *KeycloakProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
|
json, err := requests.New(profileURL).
|
||||||
json, err := requests.New(p.ValidateURL.String()).
|
|
||||||
WithContext(ctx).
|
WithContext(ctx).
|
||||||
SetHeader("Authorization", "Bearer "+s.AccessToken).
|
SetHeader("Authorization", "Bearer "+s.AccessToken).
|
||||||
Do().
|
Do().
|
||||||
UnmarshalJSON()
|
UnmarshalJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("failed making request %s", err)
|
logger.Errorf("failed making request %v", err)
|
||||||
return "", err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.Group != "" {
|
groups, err := json.Get("groups").StringArray()
|
||||||
var groups, err = json.Get("groups").Array()
|
if err == nil {
|
||||||
|
for _, group := range groups {
|
||||||
|
if group != "" {
|
||||||
|
s.Groups = append(s.Groups, group)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
email, err := json.Get("email").String()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Printf("groups not found %s", err)
|
return fmt.Errorf("unable to extract email from userinfo endpoint: %v", err)
|
||||||
return "", err
|
|
||||||
}
|
}
|
||||||
|
s.Email = email
|
||||||
|
|
||||||
var found = false
|
return nil
|
||||||
for i := range groups {
|
|
||||||
if groups[i].(string) == p.Group {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !found {
|
|
||||||
logger.Printf("group not found, access denied")
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return json.Get("email").String()
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,17 +2,24 @@ package providers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||||
|
. "github.com/onsi/ginkgo"
|
||||||
|
. "github.com/onsi/ginkgo/extensions/table"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func testKeycloakProvider(hostname, group string) *KeycloakProvider {
|
const (
|
||||||
|
keycloakAccessToken = "eyJKeycloak.eyJAccess.Token"
|
||||||
|
keycloakUserinfoPath = "/api/v3/user"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testKeycloakProvider(backend *httptest.Server) (*KeycloakProvider, error) {
|
||||||
p := NewKeycloakProvider(
|
p := NewKeycloakProvider(
|
||||||
&ProviderData{
|
&ProviderData{
|
||||||
ProviderName: "",
|
ProviderName: "",
|
||||||
|
|
@ -22,63 +29,35 @@ func testKeycloakProvider(hostname, group string) *KeycloakProvider {
|
||||||
ValidateURL: &url.URL{},
|
ValidateURL: &url.URL{},
|
||||||
Scope: ""})
|
Scope: ""})
|
||||||
|
|
||||||
if group != "" {
|
if backend != nil {
|
||||||
p.SetGroup(group)
|
bURL, err := url.Parse(backend.URL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
hostname := bURL.Host
|
||||||
|
|
||||||
if hostname != "" {
|
|
||||||
updateURL(p.Data().LoginURL, hostname)
|
updateURL(p.Data().LoginURL, hostname)
|
||||||
updateURL(p.Data().RedeemURL, hostname)
|
updateURL(p.Data().RedeemURL, hostname)
|
||||||
updateURL(p.Data().ProfileURL, hostname)
|
updateURL(p.Data().ProfileURL, hostname)
|
||||||
updateURL(p.Data().ValidateURL, hostname)
|
updateURL(p.Data().ValidateURL, hostname)
|
||||||
}
|
}
|
||||||
return p
|
|
||||||
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func testKeycloakBackend(payload string) *httptest.Server {
|
var _ = Describe("Keycloak Provider Tests", func() {
|
||||||
path := "/api/v3/user"
|
Context("New Provider Init", func() {
|
||||||
|
It("uses defaults", func() {
|
||||||
return httptest.NewServer(http.HandlerFunc(
|
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
url := r.URL
|
|
||||||
if url.Path != path {
|
|
||||||
w.WriteHeader(404)
|
|
||||||
} else if !IsAuthorizedInHeader(r.Header) {
|
|
||||||
w.WriteHeader(403)
|
|
||||||
} else {
|
|
||||||
w.WriteHeader(200)
|
|
||||||
w.Write([]byte(payload))
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestKeycloakProviderDefaults(t *testing.T) {
|
|
||||||
p := testKeycloakProvider("", "")
|
|
||||||
assert.NotEqual(t, nil, p)
|
|
||||||
assert.Equal(t, "Keycloak", p.Data().ProviderName)
|
|
||||||
assert.Equal(t, "https://keycloak.org/oauth/authorize",
|
|
||||||
p.Data().LoginURL.String())
|
|
||||||
assert.Equal(t, "https://keycloak.org/oauth/token",
|
|
||||||
p.Data().RedeemURL.String())
|
|
||||||
assert.Equal(t, "https://keycloak.org/api/v3/user",
|
|
||||||
p.Data().ValidateURL.String())
|
|
||||||
assert.Equal(t, "api", p.Data().Scope)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewKeycloakProvider(t *testing.T) {
|
|
||||||
g := NewWithT(t)
|
|
||||||
|
|
||||||
// Test that defaults are set when calling for a new provider with nothing set
|
|
||||||
providerData := NewKeycloakProvider(&ProviderData{}).Data()
|
providerData := NewKeycloakProvider(&ProviderData{}).Data()
|
||||||
g.Expect(providerData.ProviderName).To(Equal("Keycloak"))
|
Expect(providerData.ProviderName).To(Equal("Keycloak"))
|
||||||
g.Expect(providerData.LoginURL.String()).To(Equal("https://keycloak.org/oauth/authorize"))
|
Expect(providerData.LoginURL.String()).To(Equal("https://keycloak.org/oauth/authorize"))
|
||||||
g.Expect(providerData.RedeemURL.String()).To(Equal("https://keycloak.org/oauth/token"))
|
Expect(providerData.RedeemURL.String()).To(Equal("https://keycloak.org/oauth/token"))
|
||||||
g.Expect(providerData.ProfileURL.String()).To(Equal(""))
|
Expect(providerData.ProfileURL.String()).To(Equal(""))
|
||||||
g.Expect(providerData.ValidateURL.String()).To(Equal("https://keycloak.org/api/v3/user"))
|
Expect(providerData.ValidateURL.String()).To(Equal("https://keycloak.org/api/v3/user"))
|
||||||
g.Expect(providerData.Scope).To(Equal("api"))
|
Expect(providerData.Scope).To(Equal("api"))
|
||||||
}
|
})
|
||||||
|
|
||||||
func TestKeycloakProviderOverrides(t *testing.T) {
|
It("overrides defaults", func() {
|
||||||
p := NewKeycloakProvider(
|
p := NewKeycloakProvider(
|
||||||
&ProviderData{
|
&ProviderData{
|
||||||
LoginURL: &url.URL{
|
LoginURL: &url.URL{
|
||||||
|
|
@ -89,75 +68,143 @@ func TestKeycloakProviderOverrides(t *testing.T) {
|
||||||
Scheme: "https",
|
Scheme: "https",
|
||||||
Host: "example.com",
|
Host: "example.com",
|
||||||
Path: "/oauth/token"},
|
Path: "/oauth/token"},
|
||||||
|
ProfileURL: &url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "example.com",
|
||||||
|
Path: "/api/v3/user"},
|
||||||
ValidateURL: &url.URL{
|
ValidateURL: &url.URL{
|
||||||
Scheme: "https",
|
Scheme: "https",
|
||||||
Host: "example.com",
|
Host: "example.com",
|
||||||
Path: "/api/v3/user"},
|
Path: "/api/v3/user"},
|
||||||
Scope: "profile"})
|
Scope: "profile"})
|
||||||
assert.NotEqual(t, nil, p)
|
providerData := p.Data()
|
||||||
assert.Equal(t, "Keycloak", p.Data().ProviderName)
|
|
||||||
assert.Equal(t, "https://example.com/oauth/auth",
|
Expect(providerData.ProviderName).To(Equal("Keycloak"))
|
||||||
p.Data().LoginURL.String())
|
Expect(providerData.LoginURL.String()).To(Equal("https://example.com/oauth/auth"))
|
||||||
assert.Equal(t, "https://example.com/oauth/token",
|
Expect(providerData.RedeemURL.String()).To(Equal("https://example.com/oauth/token"))
|
||||||
p.Data().RedeemURL.String())
|
Expect(providerData.ProfileURL.String()).To(Equal("https://example.com/api/v3/user"))
|
||||||
assert.Equal(t, "https://example.com/api/v3/user",
|
Expect(providerData.ValidateURL.String()).To(Equal("https://example.com/api/v3/user"))
|
||||||
p.Data().ValidateURL.String())
|
Expect(providerData.Scope).To(Equal("profile"))
|
||||||
assert.Equal(t, "profile", p.Data().Scope)
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("EnrichSession", func() {
|
||||||
|
type enrichSessionTableInput struct {
|
||||||
|
backendHandler http.HandlerFunc
|
||||||
|
expectedError error
|
||||||
|
expectedEmail string
|
||||||
|
expectedGroups []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestKeycloakProviderGetEmailAddress(t *testing.T) {
|
DescribeTable("should return expected results",
|
||||||
b := testKeycloakBackend("{\"email\": \"michael.bland@gsa.gov\"}")
|
func(in enrichSessionTableInput) {
|
||||||
defer b.Close()
|
backend := httptest.NewServer(in.backendHandler)
|
||||||
|
p, err := testKeycloakProvider(backend)
|
||||||
|
Expect(err).To(BeNil())
|
||||||
|
|
||||||
bURL, _ := url.Parse(b.URL)
|
p.ProfileURL, err = url.Parse(
|
||||||
p := testKeycloakProvider(bURL.Host, "")
|
fmt.Sprintf("%s%s", backend.URL, keycloakUserinfoPath),
|
||||||
|
)
|
||||||
|
Expect(err).To(BeNil())
|
||||||
|
|
||||||
session := CreateAuthorizedSession()
|
session := &sessions.SessionState{AccessToken: keycloakAccessToken}
|
||||||
email, err := p.GetEmailAddress(context.Background(), session)
|
err = p.EnrichSession(context.Background(), session)
|
||||||
assert.Equal(t, nil, err)
|
|
||||||
assert.Equal(t, "michael.bland@gsa.gov", email)
|
if in.expectedError != nil {
|
||||||
|
Expect(err).To(Equal(in.expectedError))
|
||||||
|
} else {
|
||||||
|
Expect(err).To(BeNil())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestKeycloakProviderGetEmailAddressAndGroup(t *testing.T) {
|
Expect(session.Email).To(Equal(in.expectedEmail))
|
||||||
b := testKeycloakBackend("{\"email\": \"michael.bland@gsa.gov\", \"groups\": [\"test-grp1\", \"test-grp2\"]}")
|
|
||||||
defer b.Close()
|
|
||||||
|
|
||||||
bURL, _ := url.Parse(b.URL)
|
if in.expectedGroups != nil {
|
||||||
p := testKeycloakProvider(bURL.Host, "test-grp1")
|
Expect(session.Groups).To(Equal(in.expectedGroups))
|
||||||
|
} else {
|
||||||
session := CreateAuthorizedSession()
|
Expect(session.Groups).To(BeNil())
|
||||||
email, err := p.GetEmailAddress(context.Background(), session)
|
|
||||||
assert.Equal(t, nil, err)
|
|
||||||
assert.Equal(t, "michael.bland@gsa.gov", email)
|
|
||||||
}
|
}
|
||||||
|
},
|
||||||
// Note that trying to trigger the "failed building request" case is not
|
Entry("email and multiple groups", enrichSessionTableInput{
|
||||||
// practical, since the only way it can fail is if the URL fails to parse.
|
backendHandler: func(w http.ResponseWriter, _ *http.Request) {
|
||||||
func TestKeycloakProviderGetEmailAddressFailedRequest(t *testing.T) {
|
w.WriteHeader(200)
|
||||||
b := testKeycloakBackend("unused payload")
|
_, err := w.Write([]byte(`
|
||||||
defer b.Close()
|
{
|
||||||
|
"email": "michael.bland@gsa.gov",
|
||||||
bURL, _ := url.Parse(b.URL)
|
"groups": [
|
||||||
p := testKeycloakProvider(bURL.Host, "")
|
"test-grp1",
|
||||||
|
"test-grp2"
|
||||||
// We'll trigger a request failure by using an unexpected access
|
]
|
||||||
// token. Alternatively, we could allow the parsing of the payload as
|
|
||||||
// JSON to fail.
|
|
||||||
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
|
|
||||||
email, err := p.GetEmailAddress(context.Background(), session)
|
|
||||||
assert.NotEqual(t, nil, err)
|
|
||||||
assert.Equal(t, "", email)
|
|
||||||
}
|
}
|
||||||
|
`))
|
||||||
func TestKeycloakProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
|
if err != nil {
|
||||||
b := testKeycloakBackend("{\"foo\": \"bar\"}")
|
panic(err)
|
||||||
defer b.Close()
|
|
||||||
|
|
||||||
bURL, _ := url.Parse(b.URL)
|
|
||||||
p := testKeycloakProvider(bURL.Host, "")
|
|
||||||
|
|
||||||
session := CreateAuthorizedSession()
|
|
||||||
email, err := p.GetEmailAddress(context.Background(), session)
|
|
||||||
assert.NotEqual(t, nil, err)
|
|
||||||
assert.Equal(t, "", email)
|
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
expectedError: nil,
|
||||||
|
expectedEmail: "michael.bland@gsa.gov",
|
||||||
|
expectedGroups: []string{"test-grp1", "test-grp2"},
|
||||||
|
}),
|
||||||
|
Entry("email and single group", enrichSessionTableInput{
|
||||||
|
backendHandler: func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
_, err := w.Write([]byte(`
|
||||||
|
{
|
||||||
|
"email": "michael.bland@gsa.gov",
|
||||||
|
"groups": ["test-grp1"]
|
||||||
|
}
|
||||||
|
`))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
expectedError: nil,
|
||||||
|
expectedEmail: "michael.bland@gsa.gov",
|
||||||
|
expectedGroups: []string{"test-grp1"},
|
||||||
|
}),
|
||||||
|
Entry("email and no groups", enrichSessionTableInput{
|
||||||
|
backendHandler: func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
_, err := w.Write([]byte(`
|
||||||
|
{
|
||||||
|
"email": "michael.bland@gsa.gov"
|
||||||
|
}
|
||||||
|
`))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
expectedError: nil,
|
||||||
|
expectedEmail: "michael.bland@gsa.gov",
|
||||||
|
expectedGroups: nil,
|
||||||
|
}),
|
||||||
|
Entry("missing email", enrichSessionTableInput{
|
||||||
|
backendHandler: func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
_, err := w.Write([]byte(`
|
||||||
|
{
|
||||||
|
"groups": [
|
||||||
|
"test-grp1",
|
||||||
|
"test-grp2"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
`))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
expectedError: errors.New(
|
||||||
|
"unable to extract email from userinfo endpoint: type assertion to string failed"),
|
||||||
|
expectedEmail: "",
|
||||||
|
expectedGroups: []string{"test-grp1", "test-grp2"},
|
||||||
|
}),
|
||||||
|
Entry("request failure", enrichSessionTableInput{
|
||||||
|
backendHandler: func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(500)
|
||||||
|
},
|
||||||
|
expectedError: errors.New(`unexpected status "500": `),
|
||||||
|
expectedEmail: "",
|
||||||
|
expectedGroups: nil,
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
|
||||||
|
|
@ -2,29 +2,20 @@ package providers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
oidc "github.com/coreos/go-oidc"
|
|
||||||
"golang.org/x/oauth2"
|
|
||||||
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
const emailClaim = "email"
|
|
||||||
|
|
||||||
// OIDCProvider represents an OIDC based Identity Provider
|
// OIDCProvider represents an OIDC based Identity Provider
|
||||||
type OIDCProvider struct {
|
type OIDCProvider struct {
|
||||||
*ProviderData
|
*ProviderData
|
||||||
|
|
||||||
AllowUnverifiedEmail bool
|
|
||||||
UserIDClaim string
|
|
||||||
GroupsClaim string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOIDCProvider initiates a new OIDCProvider
|
// NewOIDCProvider initiates a new OIDCProvider
|
||||||
|
|
@ -36,10 +27,10 @@ func NewOIDCProvider(p *ProviderData) *OIDCProvider {
|
||||||
var _ Provider = (*OIDCProvider)(nil)
|
var _ Provider = (*OIDCProvider)(nil)
|
||||||
|
|
||||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||||
func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
|
func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
|
||||||
clientSecret, err := p.GetClientSecret()
|
clientSecret, err := p.GetClientSecret()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
c := oauth2.Config{
|
c := oauth2.Config{
|
||||||
|
|
@ -52,23 +43,74 @@ func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (s
|
||||||
}
|
}
|
||||||
token, err := c.Exchange(ctx, code)
|
token, err := c.Exchange(ctx, code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("token exchange: %v", err)
|
return nil, fmt.Errorf("token exchange failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// in the initial exchange the id token is mandatory
|
return p.createSession(ctx, token, false)
|
||||||
idToken, err := p.findVerifiedIDToken(ctx, token)
|
}
|
||||||
|
|
||||||
|
// EnrichSession is called after Redeem to allow providers to enrich session fields
|
||||||
|
// such as User, Email, Groups with provider specific API calls.
|
||||||
|
func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error {
|
||||||
|
if p.ProfileURL.String() == "" {
|
||||||
|
if s.Email == "" {
|
||||||
|
return errors.New("id_token did not contain an email and profileURL is not defined")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to get missing emails or groups from a profileURL
|
||||||
|
if s.Email == "" || s.Groups == nil {
|
||||||
|
err := p.enrichFromProfileURL(ctx, s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("could not verify id_token: %v", err)
|
logger.Errorf("Warning: Profile URL request failed: %v", err)
|
||||||
} else if idToken == nil {
|
}
|
||||||
return nil, fmt.Errorf("token response did not contain an id_token")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
s, err = p.createSessionState(ctx, token, idToken)
|
// If a mandatory email wasn't set, error at this point.
|
||||||
|
if s.Email == "" {
|
||||||
|
return errors.New("neither the id_token nor the profileURL set an email")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// enrichFromProfileURL enriches a session's Email & Groups via the JSON response of
|
||||||
|
// an OIDC profile URL
|
||||||
|
func (p *OIDCProvider) enrichFromProfileURL(ctx context.Context, s *sessions.SessionState) error {
|
||||||
|
respJSON, err := requests.New(p.ProfileURL.String()).
|
||||||
|
WithContext(ctx).
|
||||||
|
WithHeaders(makeOIDCHeader(s.AccessToken)).
|
||||||
|
Do().
|
||||||
|
UnmarshalJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to update session: %v", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
email, err := respJSON.Get(p.EmailClaim).String()
|
||||||
|
if err == nil && s.Email == "" {
|
||||||
|
s.Email = email
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(s.Groups) > 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
for _, group := range coerceArray(respJSON, p.GroupsClaim) {
|
||||||
|
formatted, err := formatGroup(group)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("Warning: unable to format group of type %s with error %s",
|
||||||
|
reflect.TypeOf(group), err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.Groups = append(s.Groups, formatted)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateSession checks that the session's IDToken is still valid
|
||||||
|
func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
|
||||||
|
_, err := p.Verifier.Verify(ctx, s.IDToken)
|
||||||
|
return err == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||||
|
|
@ -83,14 +125,16 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.S
|
||||||
return false, fmt.Errorf("unable to redeem refresh token: %v", err)
|
return false, fmt.Errorf("unable to redeem refresh token: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("refreshed access token %s (expired on %s)\n", s, s.ExpiresOn)
|
logger.Printf("refreshed session: %s", s)
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) {
|
// redeemRefreshToken uses a RefreshToken with the RedeemURL to refresh the
|
||||||
|
// Access Token and (probably) the ID Token.
|
||||||
|
func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error {
|
||||||
clientSecret, err := p.GetClientSecret()
|
clientSecret, err := p.GetClientSecret()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
c := oauth2.Config{
|
c := oauth2.Config{
|
||||||
|
|
@ -109,19 +153,14 @@ func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sessi
|
||||||
return fmt.Errorf("failed to get token: %v", err)
|
return fmt.Errorf("failed to get token: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// in the token refresh response the id_token is optional
|
newSession, err := p.createSession(ctx, token, true)
|
||||||
idToken, err := p.findVerifiedIDToken(ctx, token)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to extract id_token from response: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
newSession, err := p.createSessionState(ctx, token, idToken)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable create new session state from response: %v", err)
|
return fmt.Errorf("unable create new session state from response: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// It's possible that if the refresh token isn't in the token response the session will not contain an id token
|
// It's possible that if the refresh token isn't in the token response the
|
||||||
// if it doesn't it's probably better to retain the old one
|
// session will not contain an id token.
|
||||||
|
// If it doesn't it's probably better to retain the old one
|
||||||
if newSession.IDToken != "" {
|
if newSession.IDToken != "" {
|
||||||
s.IDToken = newSession.IDToken
|
s.IDToken = newSession.IDToken
|
||||||
s.Email = newSession.Email
|
s.Email = newSession.Email
|
||||||
|
|
@ -135,193 +174,62 @@ func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sessi
|
||||||
s.CreatedAt = newSession.CreatedAt
|
s.CreatedAt = newSession.CreatedAt
|
||||||
s.ExpiresOn = newSession.ExpiresOn
|
s.ExpiresOn = newSession.ExpiresOn
|
||||||
|
|
||||||
return
|
return nil
|
||||||
}
|
|
||||||
|
|
||||||
func (p *OIDCProvider) findVerifiedIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) {
|
|
||||||
|
|
||||||
getIDToken := func() (string, bool) {
|
|
||||||
rawIDToken, _ := token.Extra("id_token").(string)
|
|
||||||
return rawIDToken, len(strings.TrimSpace(rawIDToken)) > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
if rawIDToken, present := getIDToken(); present {
|
|
||||||
verifiedIDToken, err := p.Verifier.Verify(ctx, rawIDToken)
|
|
||||||
return verifiedIDToken, err
|
|
||||||
}
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token, idToken *oidc.IDToken) (*sessions.SessionState, error) {
|
|
||||||
|
|
||||||
var newSession *sessions.SessionState
|
|
||||||
|
|
||||||
if idToken == nil {
|
|
||||||
newSession = &sessions.SessionState{}
|
|
||||||
} else {
|
|
||||||
var err error
|
|
||||||
newSession, err = p.createSessionStateInternal(ctx, idToken, token)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
created := time.Now()
|
|
||||||
newSession.AccessToken = token.AccessToken
|
|
||||||
newSession.RefreshToken = token.RefreshToken
|
|
||||||
newSession.CreatedAt = &created
|
|
||||||
newSession.ExpiresOn = &token.Expiry
|
|
||||||
return newSession, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateSessionFromToken converts Bearer IDTokens into sessions
|
||||||
func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) {
|
func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) {
|
||||||
idToken, err := p.Verifier.Verify(ctx, token)
|
idToken, err := p.Verifier.Verify(ctx, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
newSession, err := p.createSessionStateInternal(ctx, idToken, nil)
|
ss, err := p.buildSessionFromClaims(idToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
newSession.AccessToken = token
|
// Allow empty Email in Bearer case since we can't hit the ProfileURL
|
||||||
newSession.IDToken = token
|
if ss.Email == "" {
|
||||||
newSession.RefreshToken = ""
|
ss.Email = ss.User
|
||||||
newSession.ExpiresOn = &idToken.Expiry
|
|
||||||
|
|
||||||
return newSession, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OIDCProvider) createSessionStateInternal(ctx context.Context, idToken *oidc.IDToken, token *oauth2.Token) (*sessions.SessionState, error) {
|
ss.AccessToken = token
|
||||||
|
ss.IDToken = token
|
||||||
|
ss.RefreshToken = ""
|
||||||
|
ss.ExpiresOn = &idToken.Expiry
|
||||||
|
|
||||||
newSession := &sessions.SessionState{}
|
return ss, nil
|
||||||
|
|
||||||
if idToken == nil {
|
|
||||||
return newSession, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
claims, err := p.findClaimsFromIDToken(ctx, idToken, token)
|
// createSession takes an oauth2.Token and creates a SessionState from it.
|
||||||
|
// It alters behavior if called from Redeem vs Refresh
|
||||||
|
func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, refresh bool) (*sessions.SessionState, error) {
|
||||||
|
idToken, err := p.verifyIDToken(ctx, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("couldn't extract claims from id_token (%v)", err)
|
switch err {
|
||||||
|
case ErrMissingIDToken:
|
||||||
|
// IDToken is mandatory in Redeem but optional in Refresh
|
||||||
|
if !refresh {
|
||||||
|
return nil, errors.New("token response did not contain an id_token")
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("could not verify id_token: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if token != nil {
|
ss, err := p.buildSessionFromClaims(idToken)
|
||||||
newSession.IDToken = token.Extra("id_token").(string)
|
|
||||||
}
|
|
||||||
|
|
||||||
newSession.Email = claims.UserID // TODO Rename SessionState.Email to .UserID in the near future
|
|
||||||
|
|
||||||
newSession.User = claims.Subject
|
|
||||||
newSession.Groups = claims.Groups
|
|
||||||
newSession.PreferredUsername = claims.PreferredUsername
|
|
||||||
|
|
||||||
verifyEmail := (p.UserIDClaim == emailClaim) && !p.AllowUnverifiedEmail
|
|
||||||
if verifyEmail && claims.Verified != nil && !*claims.Verified {
|
|
||||||
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.UserID)
|
|
||||||
}
|
|
||||||
|
|
||||||
return newSession, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ValidateSessionState checks that the session's IDToken is still valid
|
|
||||||
func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
|
|
||||||
_, err := p.Verifier.Verify(ctx, s.IDToken)
|
|
||||||
return err == nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc.IDToken, token *oauth2.Token) (*OIDCClaims, error) {
|
|
||||||
claims := &OIDCClaims{}
|
|
||||||
|
|
||||||
// Extract default claims.
|
|
||||||
if err := idToken.Claims(&claims); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to parse default id_token claims: %v", err)
|
|
||||||
}
|
|
||||||
// Extract custom claims.
|
|
||||||
if err := idToken.Claims(&claims.rawClaims); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to parse all id_token claims: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
userID := claims.rawClaims[p.UserIDClaim]
|
|
||||||
if userID != nil {
|
|
||||||
claims.UserID = fmt.Sprint(userID)
|
|
||||||
}
|
|
||||||
|
|
||||||
claims.Groups = p.extractGroupsFromRawClaims(claims.rawClaims)
|
|
||||||
|
|
||||||
// userID claim was not present or was empty in the ID Token
|
|
||||||
if claims.UserID == "" {
|
|
||||||
// BearerToken case, allow empty UserID
|
|
||||||
// ProfileURL checks below won't work since we don't have an access token
|
|
||||||
if token == nil {
|
|
||||||
claims.UserID = claims.Subject
|
|
||||||
return claims, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
profileURL := p.ProfileURL.String()
|
|
||||||
if profileURL == "" || token.AccessToken == "" {
|
|
||||||
return nil, fmt.Errorf("id_token did not contain user ID claim (%q)", p.UserIDClaim)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the userinfo endpoint profileURL is defined, then there is a chance the userinfo
|
|
||||||
// contents at the profileURL contains the email.
|
|
||||||
// Make a query to the userinfo endpoint, and attempt to locate the email from there.
|
|
||||||
respJSON, err := requests.New(profileURL).
|
|
||||||
WithContext(ctx).
|
|
||||||
WithHeaders(makeOIDCHeader(token.AccessToken)).
|
|
||||||
Do().
|
|
||||||
UnmarshalJSON()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
userID, err := respJSON.Get(p.UserIDClaim).String()
|
ss.AccessToken = token.AccessToken
|
||||||
if err != nil {
|
ss.RefreshToken = token.RefreshToken
|
||||||
return nil, fmt.Errorf("neither id_token nor userinfo endpoint contained user ID claim (%q)", p.UserIDClaim)
|
ss.IDToken = getIDToken(token)
|
||||||
}
|
|
||||||
|
|
||||||
claims.UserID = userID
|
created := time.Now()
|
||||||
}
|
ss.CreatedAt = &created
|
||||||
|
ss.ExpiresOn = &token.Expiry
|
||||||
|
|
||||||
return claims, nil
|
return ss, nil
|
||||||
}
|
|
||||||
|
|
||||||
func (p *OIDCProvider) extractGroupsFromRawClaims(rawClaims map[string]interface{}) []string {
|
|
||||||
groups := []string{}
|
|
||||||
|
|
||||||
rawGroups, ok := rawClaims[p.GroupsClaim].([]interface{})
|
|
||||||
if rawGroups != nil && ok {
|
|
||||||
for _, rawGroup := range rawGroups {
|
|
||||||
formattedGroup, err := formatGroup(rawGroup)
|
|
||||||
if err != nil {
|
|
||||||
logger.Errorf("unable to format group of type %s with error %s", reflect.TypeOf(rawGroup), err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
groups = append(groups, formattedGroup)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return groups
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func formatGroup(rawGroup interface{}) (string, error) {
|
|
||||||
group, ok := rawGroup.(string)
|
|
||||||
if !ok {
|
|
||||||
jsonGroup, err := json.Marshal(rawGroup)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
group = string(jsonGroup)
|
|
||||||
}
|
|
||||||
return group, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type OIDCClaims struct {
|
|
||||||
rawClaims map[string]interface{}
|
|
||||||
UserID string
|
|
||||||
Subject string `json:"sub"`
|
|
||||||
Verified *bool `json:"email_verified"`
|
|
||||||
PreferredUsername string `json:"preferred_username"`
|
|
||||||
Groups []string `json:"-"`
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,42 +2,18 @@ package providers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/coreos/go-oidc"
|
"github.com/coreos/go-oidc"
|
||||||
"github.com/dgrijalva/jwt-go"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"golang.org/x/oauth2"
|
|
||||||
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
const accessToken = "access_token"
|
|
||||||
const refreshToken = "refresh_token"
|
|
||||||
const clientID = "https://test.myapp.com"
|
|
||||||
const secret = "secret"
|
|
||||||
|
|
||||||
type idTokenClaims struct {
|
|
||||||
Name string `json:"name,omitempty"`
|
|
||||||
Email string `json:"email,omitempty"`
|
|
||||||
Phone string `json:"phone_number,omitempty"`
|
|
||||||
Picture string `json:"picture,omitempty"`
|
|
||||||
Groups interface{} `json:"groups,omitempty"`
|
|
||||||
OtherGroups interface{} `json:"other_groups,omitempty"`
|
|
||||||
jwt.StandardClaims
|
|
||||||
}
|
|
||||||
|
|
||||||
type redeemTokenResponse struct {
|
type redeemTokenResponse struct {
|
||||||
AccessToken string `json:"access_token"`
|
AccessToken string `json:"access_token"`
|
||||||
RefreshToken string `json:"refresh_token"`
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
|
@ -46,88 +22,12 @@ type redeemTokenResponse struct {
|
||||||
IDToken string `json:"id_token,omitempty"`
|
IDToken string `json:"id_token,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var defaultIDToken idTokenClaims = idTokenClaims{
|
|
||||||
"Jane Dobbs",
|
|
||||||
"janed@me.com",
|
|
||||||
"+4798765432",
|
|
||||||
"http://mugbook.com/janed/me.jpg",
|
|
||||||
[]string{"test:a", "test:b"},
|
|
||||||
[]string{"test:c", "test:d"},
|
|
||||||
jwt.StandardClaims{
|
|
||||||
Audience: "https://test.myapp.com",
|
|
||||||
ExpiresAt: time.Now().Add(time.Duration(5) * time.Minute).Unix(),
|
|
||||||
Id: "id-some-id",
|
|
||||||
IssuedAt: time.Now().Unix(),
|
|
||||||
Issuer: "https://issuer.example.com",
|
|
||||||
NotBefore: 0,
|
|
||||||
Subject: "123456789",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var customGroupClaimIDToken idTokenClaims = idTokenClaims{
|
|
||||||
"Jane Dobbs",
|
|
||||||
"janed@me.com",
|
|
||||||
"+4798765432",
|
|
||||||
"http://mugbook.com/janed/me.jpg",
|
|
||||||
[]map[string]interface{}{
|
|
||||||
{
|
|
||||||
"groupId": "Admin Group Id",
|
|
||||||
"roles": []string{"Admin"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
[]string{"test:c", "test:d"},
|
|
||||||
jwt.StandardClaims{
|
|
||||||
Audience: "https://test.myapp.com",
|
|
||||||
ExpiresAt: time.Now().Add(time.Duration(5) * time.Minute).Unix(),
|
|
||||||
Id: "id-some-id",
|
|
||||||
IssuedAt: time.Now().Unix(),
|
|
||||||
Issuer: "https://issuer.example.com",
|
|
||||||
NotBefore: 0,
|
|
||||||
Subject: "123456789",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var minimalIDToken idTokenClaims = idTokenClaims{
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
[]string{},
|
|
||||||
[]string{},
|
|
||||||
jwt.StandardClaims{
|
|
||||||
Audience: "https://test.myapp.com",
|
|
||||||
ExpiresAt: time.Now().Add(time.Duration(5) * time.Minute).Unix(),
|
|
||||||
Id: "id-some-id",
|
|
||||||
IssuedAt: time.Now().Unix(),
|
|
||||||
Issuer: "https://issuer.example.com",
|
|
||||||
NotBefore: 0,
|
|
||||||
Subject: "minimal",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
type fakeKeySetStub struct{}
|
|
||||||
|
|
||||||
func (fakeKeySetStub) VerifySignature(_ context.Context, jwt string) (payload []byte, err error) {
|
|
||||||
decodeString, err := base64.RawURLEncoding.DecodeString(strings.Split(jwt, ".")[1])
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
tokenClaims := &idTokenClaims{}
|
|
||||||
err = json.Unmarshal(decodeString, tokenClaims)
|
|
||||||
|
|
||||||
if err != nil || tokenClaims.Id == "this-id-fails-validation" {
|
|
||||||
return nil, fmt.Errorf("the validation failed for subject [%v]", tokenClaims.Subject)
|
|
||||||
}
|
|
||||||
|
|
||||||
return decodeString, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func newOIDCProvider(serverURL *url.URL) *OIDCProvider {
|
func newOIDCProvider(serverURL *url.URL) *OIDCProvider {
|
||||||
|
|
||||||
providerData := &ProviderData{
|
providerData := &ProviderData{
|
||||||
ProviderName: "oidc",
|
ProviderName: "oidc",
|
||||||
ClientID: clientID,
|
ClientID: oidcClientID,
|
||||||
ClientSecret: secret,
|
ClientSecret: oidcSecret,
|
||||||
LoginURL: &url.URL{
|
LoginURL: &url.URL{
|
||||||
Scheme: serverURL.Scheme,
|
Scheme: serverURL.Scheme,
|
||||||
Host: serverURL.Host,
|
Host: serverURL.Host,
|
||||||
|
|
@ -145,17 +45,16 @@ func newOIDCProvider(serverURL *url.URL) *OIDCProvider {
|
||||||
Host: serverURL.Host,
|
Host: serverURL.Host,
|
||||||
Path: "/api"},
|
Path: "/api"},
|
||||||
Scope: "openid profile offline_access",
|
Scope: "openid profile offline_access",
|
||||||
|
EmailClaim: "email",
|
||||||
|
GroupsClaim: "groups",
|
||||||
Verifier: oidc.NewVerifier(
|
Verifier: oidc.NewVerifier(
|
||||||
"https://issuer.example.com",
|
oidcIssuer,
|
||||||
fakeKeySetStub{},
|
mockJWKS{},
|
||||||
&oidc.Config{ClientID: clientID},
|
&oidc.Config{ClientID: oidcClientID},
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
p := &OIDCProvider{
|
p := &OIDCProvider{ProviderData: providerData}
|
||||||
ProviderData: providerData,
|
|
||||||
UserIDClaim: "email",
|
|
||||||
}
|
|
||||||
|
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
@ -169,22 +68,7 @@ func newOIDCServer(body []byte) (*url.URL, *httptest.Server) {
|
||||||
return u, s
|
return u, s
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSignedTestIDToken(tokenClaims idTokenClaims) (string, error) {
|
func newTestOIDCSetup(body []byte) (*httptest.Server, *OIDCProvider) {
|
||||||
key, _ := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
standardClaims := jwt.NewWithClaims(jwt.SigningMethodRS256, tokenClaims)
|
|
||||||
return standardClaims.SignedString(key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newOauth2Token() *oauth2.Token {
|
|
||||||
return &oauth2.Token{
|
|
||||||
AccessToken: accessToken,
|
|
||||||
TokenType: "Bearer",
|
|
||||||
RefreshToken: refreshToken,
|
|
||||||
Expiry: time.Time{}.Add(time.Duration(5) * time.Second),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTestSetup(body []byte) (*httptest.Server, *OIDCProvider) {
|
|
||||||
redeemURL, server := newOIDCServer(body)
|
redeemURL, server := newOIDCServer(body)
|
||||||
provider := newOIDCProvider(redeemURL)
|
provider := newOIDCProvider(redeemURL)
|
||||||
return server, provider
|
return server, provider
|
||||||
|
|
@ -201,7 +85,7 @@ func TestOIDCProviderRedeem(t *testing.T) {
|
||||||
IDToken: idToken,
|
IDToken: idToken,
|
||||||
})
|
})
|
||||||
|
|
||||||
server, provider := newTestSetup(body)
|
server, provider := newTestOIDCSetup(body)
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
session, err := provider.Redeem(context.Background(), provider.RedeemURL.String(), "code1234")
|
session, err := provider.Redeem(context.Background(), provider.RedeemURL.String(), "code1234")
|
||||||
|
|
@ -224,8 +108,8 @@ func TestOIDCProviderRedeem_custom_userid(t *testing.T) {
|
||||||
IDToken: idToken,
|
IDToken: idToken,
|
||||||
})
|
})
|
||||||
|
|
||||||
server, provider := newTestSetup(body)
|
server, provider := newTestOIDCSetup(body)
|
||||||
provider.UserIDClaim = "phone_number"
|
provider.EmailClaim = "phone_number"
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
session, err := provider.Redeem(context.Background(), provider.RedeemURL.String(), "code1234")
|
session, err := provider.Redeem(context.Background(), provider.RedeemURL.String(), "code1234")
|
||||||
|
|
@ -233,6 +117,333 @@ func TestOIDCProviderRedeem_custom_userid(t *testing.T) {
|
||||||
assert.Equal(t, defaultIDToken.Phone, session.Email)
|
assert.Equal(t, defaultIDToken.Phone, session.Email)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOIDCProvider_EnrichSession(t *testing.T) {
|
||||||
|
testCases := map[string]struct {
|
||||||
|
ExistingSession *sessions.SessionState
|
||||||
|
EmailClaim string
|
||||||
|
GroupsClaim string
|
||||||
|
ProfileJSON map[string]interface{}
|
||||||
|
ExpectedError error
|
||||||
|
ExpectedSession *sessions.SessionState
|
||||||
|
}{
|
||||||
|
"Already Populated": {
|
||||||
|
ExistingSession: &sessions.SessionState{
|
||||||
|
User: "already",
|
||||||
|
Email: "already@populated.com",
|
||||||
|
Groups: []string{"already", "populated"},
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
EmailClaim: "email",
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
ProfileJSON: map[string]interface{}{
|
||||||
|
"email": "new@thing.com",
|
||||||
|
"groups": []string{"new", "thing"},
|
||||||
|
},
|
||||||
|
ExpectedError: nil,
|
||||||
|
ExpectedSession: &sessions.SessionState{
|
||||||
|
User: "already",
|
||||||
|
Email: "already@populated.com",
|
||||||
|
Groups: []string{"already", "populated"},
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"Missing Email": {
|
||||||
|
ExistingSession: &sessions.SessionState{
|
||||||
|
User: "missing.email",
|
||||||
|
Groups: []string{"already", "populated"},
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
EmailClaim: "email",
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
ProfileJSON: map[string]interface{}{
|
||||||
|
"email": "found@email.com",
|
||||||
|
"groups": []string{"new", "thing"},
|
||||||
|
},
|
||||||
|
ExpectedError: nil,
|
||||||
|
ExpectedSession: &sessions.SessionState{
|
||||||
|
User: "missing.email",
|
||||||
|
Email: "found@email.com",
|
||||||
|
Groups: []string{"already", "populated"},
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
"Missing Email Only in Profile URL": {
|
||||||
|
ExistingSession: &sessions.SessionState{
|
||||||
|
User: "missing.email",
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
EmailClaim: "email",
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
ProfileJSON: map[string]interface{}{
|
||||||
|
"email": "found@email.com",
|
||||||
|
},
|
||||||
|
ExpectedError: nil,
|
||||||
|
ExpectedSession: &sessions.SessionState{
|
||||||
|
User: "missing.email",
|
||||||
|
Email: "found@email.com",
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"Missing Email with Custom Claim": {
|
||||||
|
ExistingSession: &sessions.SessionState{
|
||||||
|
User: "missing.email",
|
||||||
|
Groups: []string{"already", "populated"},
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
EmailClaim: "weird",
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
ProfileJSON: map[string]interface{}{
|
||||||
|
"weird": "weird@claim.com",
|
||||||
|
"groups": []string{"new", "thing"},
|
||||||
|
},
|
||||||
|
ExpectedError: nil,
|
||||||
|
ExpectedSession: &sessions.SessionState{
|
||||||
|
User: "missing.email",
|
||||||
|
Email: "weird@claim.com",
|
||||||
|
Groups: []string{"already", "populated"},
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"Missing Email not in Profile URL": {
|
||||||
|
ExistingSession: &sessions.SessionState{
|
||||||
|
User: "missing.email",
|
||||||
|
Groups: []string{"already", "populated"},
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
EmailClaim: "email",
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
ProfileJSON: map[string]interface{}{
|
||||||
|
"groups": []string{"new", "thing"},
|
||||||
|
},
|
||||||
|
ExpectedError: errors.New("neither the id_token nor the profileURL set an email"),
|
||||||
|
ExpectedSession: &sessions.SessionState{
|
||||||
|
User: "missing.email",
|
||||||
|
Groups: []string{"already", "populated"},
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"Missing Groups": {
|
||||||
|
ExistingSession: &sessions.SessionState{
|
||||||
|
User: "already",
|
||||||
|
Email: "already@populated.com",
|
||||||
|
Groups: nil,
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
EmailClaim: "email",
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
ProfileJSON: map[string]interface{}{
|
||||||
|
"email": "new@thing.com",
|
||||||
|
"groups": []string{"new", "thing"},
|
||||||
|
},
|
||||||
|
ExpectedError: nil,
|
||||||
|
ExpectedSession: &sessions.SessionState{
|
||||||
|
User: "already",
|
||||||
|
Email: "already@populated.com",
|
||||||
|
Groups: []string{"new", "thing"},
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"Missing Groups with Complex Groups in Profile URL": {
|
||||||
|
ExistingSession: &sessions.SessionState{
|
||||||
|
User: "already",
|
||||||
|
Email: "already@populated.com",
|
||||||
|
Groups: nil,
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
EmailClaim: "email",
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
ProfileJSON: map[string]interface{}{
|
||||||
|
"email": "new@thing.com",
|
||||||
|
"groups": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"groupId": "Admin Group Id",
|
||||||
|
"roles": []string{"Admin"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ExpectedError: nil,
|
||||||
|
ExpectedSession: &sessions.SessionState{
|
||||||
|
User: "already",
|
||||||
|
Email: "already@populated.com",
|
||||||
|
Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"},
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"Missing Groups with Singleton Complex Group in Profile URL": {
|
||||||
|
ExistingSession: &sessions.SessionState{
|
||||||
|
User: "already",
|
||||||
|
Email: "already@populated.com",
|
||||||
|
Groups: nil,
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
EmailClaim: "email",
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
ProfileJSON: map[string]interface{}{
|
||||||
|
"email": "new@thing.com",
|
||||||
|
"groups": map[string]interface{}{
|
||||||
|
"groupId": "Admin Group Id",
|
||||||
|
"roles": []string{"Admin"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ExpectedError: nil,
|
||||||
|
ExpectedSession: &sessions.SessionState{
|
||||||
|
User: "already",
|
||||||
|
Email: "already@populated.com",
|
||||||
|
Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"},
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"Empty Groups Claims": {
|
||||||
|
ExistingSession: &sessions.SessionState{
|
||||||
|
User: "already",
|
||||||
|
Email: "already@populated.com",
|
||||||
|
Groups: []string{},
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
EmailClaim: "email",
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
ProfileJSON: map[string]interface{}{
|
||||||
|
"email": "new@thing.com",
|
||||||
|
"groups": []string{"new", "thing"},
|
||||||
|
},
|
||||||
|
ExpectedError: nil,
|
||||||
|
ExpectedSession: &sessions.SessionState{
|
||||||
|
User: "already",
|
||||||
|
Email: "already@populated.com",
|
||||||
|
Groups: []string{},
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"Missing Groups with Custom Claim": {
|
||||||
|
ExistingSession: &sessions.SessionState{
|
||||||
|
User: "already",
|
||||||
|
Email: "already@populated.com",
|
||||||
|
Groups: nil,
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
EmailClaim: "email",
|
||||||
|
GroupsClaim: "roles",
|
||||||
|
ProfileJSON: map[string]interface{}{
|
||||||
|
"email": "new@thing.com",
|
||||||
|
"roles": []string{"new", "thing", "roles"},
|
||||||
|
},
|
||||||
|
ExpectedError: nil,
|
||||||
|
ExpectedSession: &sessions.SessionState{
|
||||||
|
User: "already",
|
||||||
|
Email: "already@populated.com",
|
||||||
|
Groups: []string{"new", "thing", "roles"},
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"Missing Groups String Profile URL Response": {
|
||||||
|
ExistingSession: &sessions.SessionState{
|
||||||
|
User: "already",
|
||||||
|
Email: "already@populated.com",
|
||||||
|
Groups: nil,
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
EmailClaim: "email",
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
ProfileJSON: map[string]interface{}{
|
||||||
|
"email": "new@thing.com",
|
||||||
|
"groups": "singleton",
|
||||||
|
},
|
||||||
|
ExpectedError: nil,
|
||||||
|
ExpectedSession: &sessions.SessionState{
|
||||||
|
User: "already",
|
||||||
|
Email: "already@populated.com",
|
||||||
|
Groups: []string{"singleton"},
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"Missing Groups in both Claims and Profile URL": {
|
||||||
|
ExistingSession: &sessions.SessionState{
|
||||||
|
User: "already",
|
||||||
|
Email: "already@populated.com",
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
EmailClaim: "email",
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
ProfileJSON: map[string]interface{}{
|
||||||
|
"email": "new@thing.com",
|
||||||
|
},
|
||||||
|
ExpectedError: nil,
|
||||||
|
ExpectedSession: &sessions.SessionState{
|
||||||
|
User: "already",
|
||||||
|
Email: "already@populated.com",
|
||||||
|
IDToken: idToken,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for testName, tc := range testCases {
|
||||||
|
t.Run(testName, func(t *testing.T) {
|
||||||
|
jsonResp, err := json.Marshal(tc.ProfileJSON)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
server, provider := newTestOIDCSetup(jsonResp)
|
||||||
|
provider.ProfileURL, err = url.Parse(server.URL)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
provider.EmailClaim = tc.EmailClaim
|
||||||
|
provider.GroupsClaim = tc.GroupsClaim
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
err = provider.EnrichSession(context.Background(), tc.ExistingSession)
|
||||||
|
assert.Equal(t, tc.ExpectedError, err)
|
||||||
|
assert.Equal(t, *tc.ExpectedSession, *tc.ExistingSession)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) {
|
func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) {
|
||||||
|
|
||||||
idToken, _ := newSignedTestIDToken(defaultIDToken)
|
idToken, _ := newSignedTestIDToken(defaultIDToken)
|
||||||
|
|
@ -243,7 +454,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) {
|
||||||
RefreshToken: refreshToken,
|
RefreshToken: refreshToken,
|
||||||
})
|
})
|
||||||
|
|
||||||
server, provider := newTestSetup(body)
|
server, provider := newTestOIDCSetup(body)
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
existingSession := &sessions.SessionState{
|
existingSession := &sessions.SessionState{
|
||||||
|
|
@ -277,7 +488,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) {
|
||||||
IDToken: idToken,
|
IDToken: idToken,
|
||||||
})
|
})
|
||||||
|
|
||||||
server, provider := newTestSetup(body)
|
server, provider := newTestOIDCSetup(body)
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
existingSession := &sessions.SessionState{
|
existingSession := &sessions.SessionState{
|
||||||
|
|
@ -300,48 +511,45 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOIDCProviderCreateSessionFromToken(t *testing.T) {
|
func TestOIDCProviderCreateSessionFromToken(t *testing.T) {
|
||||||
const profileURLEmail = "janed@me.com"
|
|
||||||
|
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
IDToken idTokenClaims
|
IDToken idTokenClaims
|
||||||
GroupsClaim string
|
GroupsClaim string
|
||||||
ExpectedUser string
|
ExpectedUser string
|
||||||
ExpectedEmail string
|
ExpectedEmail string
|
||||||
ExpectedGroups interface{}
|
ExpectedGroups []string
|
||||||
}{
|
}{
|
||||||
"Default IDToken": {
|
"Default IDToken": {
|
||||||
IDToken: defaultIDToken,
|
IDToken: defaultIDToken,
|
||||||
GroupsClaim: "groups",
|
GroupsClaim: "groups",
|
||||||
ExpectedUser: defaultIDToken.Subject,
|
ExpectedUser: "123456789",
|
||||||
ExpectedEmail: defaultIDToken.Email,
|
ExpectedEmail: "janed@me.com",
|
||||||
ExpectedGroups: []string{"test:a", "test:b"},
|
ExpectedGroups: []string{"test:a", "test:b"},
|
||||||
},
|
},
|
||||||
"Minimal IDToken with no email claim": {
|
"Minimal IDToken with no email claim": {
|
||||||
IDToken: minimalIDToken,
|
IDToken: minimalIDToken,
|
||||||
GroupsClaim: "groups",
|
GroupsClaim: "groups",
|
||||||
ExpectedUser: minimalIDToken.Subject,
|
ExpectedUser: "123456789",
|
||||||
ExpectedEmail: minimalIDToken.Subject,
|
ExpectedEmail: "123456789",
|
||||||
ExpectedGroups: []string{},
|
ExpectedGroups: nil,
|
||||||
},
|
},
|
||||||
"Custom Groups Claim": {
|
"Custom Groups Claim": {
|
||||||
IDToken: defaultIDToken,
|
IDToken: defaultIDToken,
|
||||||
GroupsClaim: "other_groups",
|
GroupsClaim: "roles",
|
||||||
ExpectedUser: defaultIDToken.Subject,
|
ExpectedUser: "123456789",
|
||||||
ExpectedEmail: defaultIDToken.Email,
|
ExpectedEmail: "janed@me.com",
|
||||||
ExpectedGroups: []string{"test:c", "test:d"},
|
ExpectedGroups: []string{"test:c", "test:d"},
|
||||||
},
|
},
|
||||||
"Custom Groups Claim2": {
|
"Complex Groups Claim": {
|
||||||
IDToken: customGroupClaimIDToken,
|
IDToken: complexGroupsIDToken,
|
||||||
GroupsClaim: "groups",
|
GroupsClaim: "groups",
|
||||||
ExpectedUser: customGroupClaimIDToken.Subject,
|
ExpectedUser: "123456789",
|
||||||
ExpectedEmail: customGroupClaimIDToken.Email,
|
ExpectedEmail: "complex@claims.com",
|
||||||
ExpectedGroups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"},
|
ExpectedGroups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for testName, tc := range testCases {
|
for testName, tc := range testCases {
|
||||||
t.Run(testName, func(t *testing.T) {
|
t.Run(testName, func(t *testing.T) {
|
||||||
jsonResp := []byte(fmt.Sprintf(`{"email":"%s"}`, profileURLEmail))
|
server, provider := newTestOIDCSetup([]byte(`{}`))
|
||||||
server, provider := newTestSetup(jsonResp)
|
|
||||||
provider.GroupsClaim = tc.GroupsClaim
|
provider.GroupsClaim = tc.GroupsClaim
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
|
|
@ -353,75 +561,10 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) {
|
||||||
|
|
||||||
assert.Equal(t, tc.ExpectedUser, ss.User)
|
assert.Equal(t, tc.ExpectedUser, ss.User)
|
||||||
assert.Equal(t, tc.ExpectedEmail, ss.Email)
|
assert.Equal(t, tc.ExpectedEmail, ss.Email)
|
||||||
|
assert.Equal(t, tc.ExpectedGroups, ss.Groups)
|
||||||
assert.Equal(t, rawIDToken, ss.IDToken)
|
assert.Equal(t, rawIDToken, ss.IDToken)
|
||||||
assert.Equal(t, rawIDToken, ss.AccessToken)
|
assert.Equal(t, rawIDToken, ss.AccessToken)
|
||||||
assert.Equal(t, tc.ExpectedGroups, ss.Groups)
|
|
||||||
assert.Equal(t, "", ss.RefreshToken)
|
assert.Equal(t, "", ss.RefreshToken)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOIDCProvider_findVerifiedIdToken(t *testing.T) {
|
|
||||||
|
|
||||||
server, provider := newTestSetup([]byte(""))
|
|
||||||
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
token := newOauth2Token()
|
|
||||||
signedIDToken, _ := newSignedTestIDToken(defaultIDToken)
|
|
||||||
tokenWithIDToken := token.WithExtra(map[string]interface{}{
|
|
||||||
"id_token": signedIDToken,
|
|
||||||
})
|
|
||||||
|
|
||||||
verifiedIDToken, err := provider.findVerifiedIDToken(context.Background(), tokenWithIDToken)
|
|
||||||
assert.Equal(t, true, err == nil)
|
|
||||||
if verifiedIDToken == nil {
|
|
||||||
t.Fatal("verifiedIDToken is nil")
|
|
||||||
}
|
|
||||||
assert.Equal(t, defaultIDToken.Issuer, verifiedIDToken.Issuer)
|
|
||||||
assert.Equal(t, defaultIDToken.Subject, verifiedIDToken.Subject)
|
|
||||||
|
|
||||||
// When the validation fails the response should be nil
|
|
||||||
defaultIDToken.Id = "this-id-fails-validation"
|
|
||||||
signedIDToken, _ = newSignedTestIDToken(defaultIDToken)
|
|
||||||
tokenWithIDToken = token.WithExtra(map[string]interface{}{
|
|
||||||
"id_token": signedIDToken,
|
|
||||||
})
|
|
||||||
|
|
||||||
verifiedIDToken, err = provider.findVerifiedIDToken(context.Background(), tokenWithIDToken)
|
|
||||||
assert.Equal(t, errors.New("failed to verify signature: the validation failed for subject [123456789]"), err)
|
|
||||||
assert.Equal(t, true, verifiedIDToken == nil)
|
|
||||||
|
|
||||||
// When there is no id token in the oauth token
|
|
||||||
verifiedIDToken, err = provider.findVerifiedIDToken(context.Background(), newOauth2Token())
|
|
||||||
assert.Equal(t, nil, err)
|
|
||||||
assert.Equal(t, true, verifiedIDToken == nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_formatGroup(t *testing.T) {
|
|
||||||
testCases := map[string]struct {
|
|
||||||
RawGroup interface{}
|
|
||||||
ExpectedFormattedGroupValue string
|
|
||||||
}{
|
|
||||||
"String Group": {
|
|
||||||
RawGroup: "group",
|
|
||||||
ExpectedFormattedGroupValue: "group",
|
|
||||||
},
|
|
||||||
"Map Group": {
|
|
||||||
RawGroup: map[string]string{"id": "1", "name": "Test"},
|
|
||||||
ExpectedFormattedGroupValue: "{\"id\":\"1\",\"name\":\"Test\"}",
|
|
||||||
},
|
|
||||||
"List Group": {
|
|
||||||
RawGroup: []string{"First", "Second"},
|
|
||||||
ExpectedFormattedGroupValue: "[\"First\",\"Second\"]",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for testName, tc := range testCases {
|
|
||||||
t.Run(testName, func(t *testing.T) {
|
|
||||||
formattedGroup, err := formatGroup(tc.RawGroup)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, tc.ExpectedFormattedGroupValue, formattedGroup)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,23 @@
|
||||||
package providers
|
package providers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/coreos/go-oidc"
|
"github.com/coreos/go-oidc"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
OIDCEmailClaim = "email"
|
||||||
|
OIDCGroupsClaim = "groups"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ProviderData contains information required to configure all implementations
|
// ProviderData contains information required to configure all implementations
|
||||||
|
|
@ -27,6 +38,11 @@ type ProviderData struct {
|
||||||
ClientSecretFile string
|
ClientSecretFile string
|
||||||
Scope string
|
Scope string
|
||||||
Prompt string
|
Prompt string
|
||||||
|
|
||||||
|
// Common OIDC options for any OIDC-based providers to consume
|
||||||
|
AllowUnverifiedEmail bool
|
||||||
|
EmailClaim string
|
||||||
|
GroupsClaim string
|
||||||
Verifier *oidc.IDTokenVerifier
|
Verifier *oidc.IDTokenVerifier
|
||||||
|
|
||||||
// Universal Group authorization data structure
|
// Universal Group authorization data structure
|
||||||
|
|
@ -94,3 +110,116 @@ func defaultURL(u *url.URL, d *url.URL) *url.URL {
|
||||||
}
|
}
|
||||||
return &url.URL{}
|
return &url.URL{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ****************************************************************************
|
||||||
|
// These private OIDC helper methods are available to any providers that are
|
||||||
|
// OIDC compliant
|
||||||
|
// ****************************************************************************
|
||||||
|
|
||||||
|
// OIDCClaims is a struct to unmarshal the OIDC claims from an ID Token payload
|
||||||
|
type OIDCClaims struct {
|
||||||
|
Subject string `json:"sub"`
|
||||||
|
Email string `json:"-"`
|
||||||
|
Groups []string `json:"-"`
|
||||||
|
Verified *bool `json:"email_verified"`
|
||||||
|
|
||||||
|
raw map[string]interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProviderData) verifyIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) {
|
||||||
|
rawIDToken := getIDToken(token)
|
||||||
|
if strings.TrimSpace(rawIDToken) == "" {
|
||||||
|
return nil, ErrMissingIDToken
|
||||||
|
}
|
||||||
|
if p.Verifier == nil {
|
||||||
|
return nil, ErrMissingOIDCVerifier
|
||||||
|
}
|
||||||
|
return p.Verifier.Verify(ctx, rawIDToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildSessionFromClaims uses IDToken claims to populate a fresh SessionState
|
||||||
|
// with non-Token related fields.
|
||||||
|
func (p *ProviderData) buildSessionFromClaims(idToken *oidc.IDToken) (*sessions.SessionState, error) {
|
||||||
|
ss := &sessions.SessionState{}
|
||||||
|
|
||||||
|
if idToken == nil {
|
||||||
|
return ss, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, err := p.getClaims(idToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("couldn't extract claims from id_token (%v)", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ss.User = claims.Subject
|
||||||
|
ss.Email = claims.Email
|
||||||
|
ss.Groups = claims.Groups
|
||||||
|
|
||||||
|
// TODO (@NickMeves) Deprecate for dynamic claim to session mapping
|
||||||
|
if pref, ok := claims.raw["preferred_username"].(string); ok {
|
||||||
|
ss.PreferredUsername = pref
|
||||||
|
}
|
||||||
|
|
||||||
|
// `email_verified` must be present and explicitly set to `false` to be
|
||||||
|
// considered unverified.
|
||||||
|
verifyEmail := (p.EmailClaim == OIDCEmailClaim) && !p.AllowUnverifiedEmail
|
||||||
|
if verifyEmail && claims.Verified != nil && !*claims.Verified {
|
||||||
|
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ss, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getClaims extracts IDToken claims into an OIDCClaims
|
||||||
|
func (p *ProviderData) getClaims(idToken *oidc.IDToken) (*OIDCClaims, error) {
|
||||||
|
claims := &OIDCClaims{}
|
||||||
|
|
||||||
|
// Extract default claims.
|
||||||
|
if err := idToken.Claims(&claims); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse default id_token claims: %v", err)
|
||||||
|
}
|
||||||
|
// Extract custom claims.
|
||||||
|
if err := idToken.Claims(&claims.raw); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse all id_token claims: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
email := claims.raw[p.EmailClaim]
|
||||||
|
if email != nil {
|
||||||
|
claims.Email = fmt.Sprint(email)
|
||||||
|
}
|
||||||
|
claims.Groups = p.extractGroups(claims.raw)
|
||||||
|
|
||||||
|
return claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractGroups extracts groups from a claim to a list in a type safe manner.
|
||||||
|
// If the claim isn't present, `nil` is returned. If the groups claim is
|
||||||
|
// present but empty, `[]string{}` is returned.
|
||||||
|
func (p *ProviderData) extractGroups(claims map[string]interface{}) []string {
|
||||||
|
rawClaim, ok := claims[p.GroupsClaim]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle traditional list-based groups as well as non-standard singleton
|
||||||
|
// based groups. Both variants support complex objects if needed.
|
||||||
|
var claimGroups []interface{}
|
||||||
|
switch raw := rawClaim.(type) {
|
||||||
|
case []interface{}:
|
||||||
|
claimGroups = raw
|
||||||
|
case interface{}:
|
||||||
|
claimGroups = []interface{}{raw}
|
||||||
|
}
|
||||||
|
|
||||||
|
groups := []string{}
|
||||||
|
for _, rawGroup := range claimGroups {
|
||||||
|
formattedGroup, err := formatGroup(rawGroup)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("Warning: unable to format group of type %s with error %s",
|
||||||
|
reflect.TypeOf(rawGroup), err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
groups = append(groups, formattedGroup)
|
||||||
|
}
|
||||||
|
return groups
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,437 @@
|
||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/coreos/go-oidc"
|
||||||
|
"github.com/dgrijalva/jwt-go"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
idToken = "eyJfoobar123.eyJbaz987.IDToken"
|
||||||
|
accessToken = "eyJfoobar123.eyJbaz987.AccessToken"
|
||||||
|
refreshToken = "eyJfoobar123.eyJbaz987.RefreshToken"
|
||||||
|
|
||||||
|
oidcIssuer = "https://issuer.example.com"
|
||||||
|
oidcClientID = "https://test.myapp.com"
|
||||||
|
oidcSecret = "SuperSecret123456789"
|
||||||
|
|
||||||
|
failureTokenID = "this-id-fails-verification"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
verified = true
|
||||||
|
unverified = false
|
||||||
|
|
||||||
|
standardClaims = jwt.StandardClaims{
|
||||||
|
Audience: oidcClientID,
|
||||||
|
ExpiresAt: time.Now().Add(time.Duration(5) * time.Minute).Unix(),
|
||||||
|
Id: "id-some-id",
|
||||||
|
IssuedAt: time.Now().Unix(),
|
||||||
|
Issuer: oidcIssuer,
|
||||||
|
NotBefore: 0,
|
||||||
|
Subject: "123456789",
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultIDToken = idTokenClaims{
|
||||||
|
Name: "Jane Dobbs",
|
||||||
|
Email: "janed@me.com",
|
||||||
|
Phone: "+4798765432",
|
||||||
|
Picture: "http://mugbook.com/janed/me.jpg",
|
||||||
|
Groups: []string{"test:a", "test:b"},
|
||||||
|
Roles: []string{"test:c", "test:d"},
|
||||||
|
Verified: &verified,
|
||||||
|
StandardClaims: standardClaims,
|
||||||
|
}
|
||||||
|
|
||||||
|
complexGroupsIDToken = idTokenClaims{
|
||||||
|
Name: "Complex Claim",
|
||||||
|
Email: "complex@claims.com",
|
||||||
|
Phone: "+5439871234",
|
||||||
|
Picture: "http://mugbook.com/complex/claims.jpg",
|
||||||
|
Groups: []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"groupId": "Admin Group Id",
|
||||||
|
"roles": []string{"Admin"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Roles: []string{"test:simple", "test:roles"},
|
||||||
|
Verified: &verified,
|
||||||
|
StandardClaims: standardClaims,
|
||||||
|
}
|
||||||
|
|
||||||
|
unverifiedIDToken = idTokenClaims{
|
||||||
|
Name: "Mystery Man",
|
||||||
|
Email: "unverified@email.com",
|
||||||
|
Phone: "+4025205729",
|
||||||
|
Picture: "http://mugbook.com/unverified/email.jpg",
|
||||||
|
Groups: []string{"test:a", "test:b"},
|
||||||
|
Roles: []string{"test:c", "test:d"},
|
||||||
|
Verified: &unverified,
|
||||||
|
StandardClaims: standardClaims,
|
||||||
|
}
|
||||||
|
|
||||||
|
minimalIDToken = idTokenClaims{
|
||||||
|
StandardClaims: standardClaims,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
type idTokenClaims struct {
|
||||||
|
Name string `json:"preferred_username,omitempty"`
|
||||||
|
Email string `json:"email,omitempty"`
|
||||||
|
Phone string `json:"phone_number,omitempty"`
|
||||||
|
Picture string `json:"picture,omitempty"`
|
||||||
|
Groups interface{} `json:"groups,omitempty"`
|
||||||
|
Roles interface{} `json:"roles,omitempty"`
|
||||||
|
Verified *bool `json:"email_verified,omitempty"`
|
||||||
|
jwt.StandardClaims
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockJWKS struct{}
|
||||||
|
|
||||||
|
func (mockJWKS) VerifySignature(_ context.Context, jwt string) ([]byte, error) {
|
||||||
|
decoded, err := base64.RawURLEncoding.DecodeString(strings.Split(jwt, ".")[1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenClaims := &idTokenClaims{}
|
||||||
|
err = json.Unmarshal(decoded, tokenClaims)
|
||||||
|
if err != nil || tokenClaims.Id == failureTokenID {
|
||||||
|
return nil, fmt.Errorf("the validation failed for subject [%v]", tokenClaims.Subject)
|
||||||
|
}
|
||||||
|
|
||||||
|
return decoded, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSignedTestIDToken(tokenClaims idTokenClaims) (string, error) {
|
||||||
|
key, _ := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
standardClaims := jwt.NewWithClaims(jwt.SigningMethodRS256, tokenClaims)
|
||||||
|
return standardClaims.SignedString(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestOauth2Token() *oauth2.Token {
|
||||||
|
return &oauth2.Token{
|
||||||
|
AccessToken: accessToken,
|
||||||
|
TokenType: "Bearer",
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
Expiry: time.Time{}.Add(time.Duration(5) * time.Second),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderData_verifyIDToken(t *testing.T) {
|
||||||
|
failureIDToken := defaultIDToken
|
||||||
|
failureIDToken.Id = failureTokenID
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
IDToken *idTokenClaims
|
||||||
|
Verifier bool
|
||||||
|
ExpectIDToken bool
|
||||||
|
ExpectedError error
|
||||||
|
}{
|
||||||
|
"Valid ID Token": {
|
||||||
|
IDToken: &defaultIDToken,
|
||||||
|
Verifier: true,
|
||||||
|
ExpectIDToken: true,
|
||||||
|
ExpectedError: nil,
|
||||||
|
},
|
||||||
|
"Invalid ID Token": {
|
||||||
|
IDToken: &failureIDToken,
|
||||||
|
Verifier: true,
|
||||||
|
ExpectIDToken: false,
|
||||||
|
ExpectedError: errors.New("failed to verify signature: the validation failed for subject [123456789]"),
|
||||||
|
},
|
||||||
|
"Missing ID Token": {
|
||||||
|
IDToken: nil,
|
||||||
|
Verifier: true,
|
||||||
|
ExpectIDToken: false,
|
||||||
|
ExpectedError: ErrMissingIDToken,
|
||||||
|
},
|
||||||
|
"OIDC Verifier not Configured": {
|
||||||
|
IDToken: &defaultIDToken,
|
||||||
|
Verifier: false,
|
||||||
|
ExpectIDToken: false,
|
||||||
|
ExpectedError: ErrMissingOIDCVerifier,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for testName, tc := range testCases {
|
||||||
|
t.Run(testName, func(t *testing.T) {
|
||||||
|
g := NewWithT(t)
|
||||||
|
|
||||||
|
token := newTestOauth2Token()
|
||||||
|
if tc.IDToken != nil {
|
||||||
|
idToken, err := newSignedTestIDToken(*tc.IDToken)
|
||||||
|
g.Expect(err).ToNot(HaveOccurred())
|
||||||
|
token = token.WithExtra(map[string]interface{}{
|
||||||
|
"id_token": idToken,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
provider := &ProviderData{}
|
||||||
|
if tc.Verifier {
|
||||||
|
provider.Verifier = oidc.NewVerifier(
|
||||||
|
oidcIssuer,
|
||||||
|
mockJWKS{},
|
||||||
|
&oidc.Config{ClientID: oidcClientID},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
verified, err := provider.verifyIDToken(context.Background(), token)
|
||||||
|
if err != nil {
|
||||||
|
g.Expect(err).To(Equal(tc.ExpectedError))
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.ExpectIDToken {
|
||||||
|
g.Expect(verified).ToNot(BeNil())
|
||||||
|
g.Expect(*verified).To(BeAssignableToTypeOf(oidc.IDToken{}))
|
||||||
|
} else {
|
||||||
|
g.Expect(verified).To(BeNil())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderData_buildSessionFromClaims(t *testing.T) {
|
||||||
|
testCases := map[string]struct {
|
||||||
|
IDToken idTokenClaims
|
||||||
|
AllowUnverified bool
|
||||||
|
EmailClaim string
|
||||||
|
GroupsClaim string
|
||||||
|
ExpectedError error
|
||||||
|
ExpectedSession *sessions.SessionState
|
||||||
|
}{
|
||||||
|
"Standard": {
|
||||||
|
IDToken: defaultIDToken,
|
||||||
|
AllowUnverified: false,
|
||||||
|
EmailClaim: "email",
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
ExpectedSession: &sessions.SessionState{
|
||||||
|
User: "123456789",
|
||||||
|
Email: "janed@me.com",
|
||||||
|
Groups: []string{"test:a", "test:b"},
|
||||||
|
PreferredUsername: "Jane Dobbs",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"Unverified Denied": {
|
||||||
|
IDToken: unverifiedIDToken,
|
||||||
|
AllowUnverified: false,
|
||||||
|
EmailClaim: "email",
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
ExpectedError: errors.New("email in id_token (unverified@email.com) isn't verified"),
|
||||||
|
},
|
||||||
|
"Unverified Allowed": {
|
||||||
|
IDToken: unverifiedIDToken,
|
||||||
|
AllowUnverified: true,
|
||||||
|
EmailClaim: "email",
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
ExpectedSession: &sessions.SessionState{
|
||||||
|
User: "123456789",
|
||||||
|
Email: "unverified@email.com",
|
||||||
|
Groups: []string{"test:a", "test:b"},
|
||||||
|
PreferredUsername: "Mystery Man",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"Complex Groups": {
|
||||||
|
IDToken: complexGroupsIDToken,
|
||||||
|
AllowUnverified: true,
|
||||||
|
EmailClaim: "email",
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
ExpectedSession: &sessions.SessionState{
|
||||||
|
User: "123456789",
|
||||||
|
Email: "complex@claims.com",
|
||||||
|
Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"},
|
||||||
|
PreferredUsername: "Complex Claim",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"Email Claim Switched": {
|
||||||
|
IDToken: unverifiedIDToken,
|
||||||
|
AllowUnverified: true,
|
||||||
|
EmailClaim: "phone_number",
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
ExpectedSession: &sessions.SessionState{
|
||||||
|
User: "123456789",
|
||||||
|
Email: "+4025205729",
|
||||||
|
Groups: []string{"test:a", "test:b"},
|
||||||
|
PreferredUsername: "Mystery Man",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"Email Claim Switched to Non String": {
|
||||||
|
IDToken: unverifiedIDToken,
|
||||||
|
AllowUnverified: true,
|
||||||
|
EmailClaim: "roles",
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
ExpectedSession: &sessions.SessionState{
|
||||||
|
User: "123456789",
|
||||||
|
Email: "[test:c test:d]",
|
||||||
|
Groups: []string{"test:a", "test:b"},
|
||||||
|
PreferredUsername: "Mystery Man",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"Email Claim Non Existent": {
|
||||||
|
IDToken: unverifiedIDToken,
|
||||||
|
AllowUnverified: true,
|
||||||
|
EmailClaim: "aksjdfhjksadh",
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
ExpectedSession: &sessions.SessionState{
|
||||||
|
User: "123456789",
|
||||||
|
Email: "",
|
||||||
|
Groups: []string{"test:a", "test:b"},
|
||||||
|
PreferredUsername: "Mystery Man",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"Groups Claim Switched": {
|
||||||
|
IDToken: defaultIDToken,
|
||||||
|
AllowUnverified: false,
|
||||||
|
EmailClaim: "email",
|
||||||
|
GroupsClaim: "roles",
|
||||||
|
ExpectedSession: &sessions.SessionState{
|
||||||
|
User: "123456789",
|
||||||
|
Email: "janed@me.com",
|
||||||
|
Groups: []string{"test:c", "test:d"},
|
||||||
|
PreferredUsername: "Jane Dobbs",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"Groups Claim Non Existent": {
|
||||||
|
IDToken: defaultIDToken,
|
||||||
|
AllowUnverified: false,
|
||||||
|
EmailClaim: "email",
|
||||||
|
GroupsClaim: "alskdjfsalkdjf",
|
||||||
|
ExpectedSession: &sessions.SessionState{
|
||||||
|
User: "123456789",
|
||||||
|
Email: "janed@me.com",
|
||||||
|
Groups: nil,
|
||||||
|
PreferredUsername: "Jane Dobbs",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for testName, tc := range testCases {
|
||||||
|
t.Run(testName, func(t *testing.T) {
|
||||||
|
g := NewWithT(t)
|
||||||
|
|
||||||
|
provider := &ProviderData{
|
||||||
|
Verifier: oidc.NewVerifier(
|
||||||
|
oidcIssuer,
|
||||||
|
mockJWKS{},
|
||||||
|
&oidc.Config{ClientID: oidcClientID},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
provider.AllowUnverifiedEmail = tc.AllowUnverified
|
||||||
|
provider.EmailClaim = tc.EmailClaim
|
||||||
|
provider.GroupsClaim = tc.GroupsClaim
|
||||||
|
|
||||||
|
rawIDToken, err := newSignedTestIDToken(tc.IDToken)
|
||||||
|
g.Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
idToken, err := provider.Verifier.Verify(context.Background(), rawIDToken)
|
||||||
|
g.Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
ss, err := provider.buildSessionFromClaims(idToken)
|
||||||
|
if err != nil {
|
||||||
|
g.Expect(err).To(Equal(tc.ExpectedError))
|
||||||
|
}
|
||||||
|
if ss != nil {
|
||||||
|
g.Expect(ss).To(Equal(tc.ExpectedSession))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderData_extractGroups(t *testing.T) {
|
||||||
|
testCases := map[string]struct {
|
||||||
|
Claims map[string]interface{}
|
||||||
|
GroupsClaim string
|
||||||
|
ExpectedGroups []string
|
||||||
|
}{
|
||||||
|
"Standard String Groups": {
|
||||||
|
Claims: map[string]interface{}{
|
||||||
|
"email": "this@does.not.matter.com",
|
||||||
|
"groups": []interface{}{"three", "string", "groups"},
|
||||||
|
},
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
ExpectedGroups: []string{"three", "string", "groups"},
|
||||||
|
},
|
||||||
|
"Different Claim Name": {
|
||||||
|
Claims: map[string]interface{}{
|
||||||
|
"email": "this@does.not.matter.com",
|
||||||
|
"roles": []interface{}{"three", "string", "roles"},
|
||||||
|
},
|
||||||
|
GroupsClaim: "roles",
|
||||||
|
ExpectedGroups: []string{"three", "string", "roles"},
|
||||||
|
},
|
||||||
|
"Numeric Groups": {
|
||||||
|
Claims: map[string]interface{}{
|
||||||
|
"email": "this@does.not.matter.com",
|
||||||
|
"groups": []interface{}{1, 2, 3},
|
||||||
|
},
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
ExpectedGroups: []string{"1", "2", "3"},
|
||||||
|
},
|
||||||
|
"Complex Groups": {
|
||||||
|
Claims: map[string]interface{}{
|
||||||
|
"email": "this@does.not.matter.com",
|
||||||
|
"groups": []interface{}{
|
||||||
|
map[string]interface{}{
|
||||||
|
"groupId": "Admin Group Id",
|
||||||
|
"roles": []string{"Admin"},
|
||||||
|
},
|
||||||
|
12345,
|
||||||
|
"Just::A::String",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
ExpectedGroups: []string{
|
||||||
|
"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}",
|
||||||
|
"12345",
|
||||||
|
"Just::A::String",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"Missing Groups Claim Returns Nil": {
|
||||||
|
Claims: map[string]interface{}{
|
||||||
|
"email": "this@does.not.matter.com",
|
||||||
|
},
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
ExpectedGroups: nil,
|
||||||
|
},
|
||||||
|
"Non List Groups": {
|
||||||
|
Claims: map[string]interface{}{
|
||||||
|
"email": "this@does.not.matter.com",
|
||||||
|
"groups": "singleton",
|
||||||
|
},
|
||||||
|
GroupsClaim: "groups",
|
||||||
|
ExpectedGroups: []string{"singleton"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for testName, tc := range testCases {
|
||||||
|
t.Run(testName, func(t *testing.T) {
|
||||||
|
g := NewWithT(t)
|
||||||
|
|
||||||
|
provider := &ProviderData{
|
||||||
|
Verifier: oidc.NewVerifier(
|
||||||
|
oidcIssuer,
|
||||||
|
mockJWKS{},
|
||||||
|
&oidc.Config{ClientID: oidcClientID},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
provider.GroupsClaim = tc.GroupsClaim
|
||||||
|
|
||||||
|
groups := provider.extractGroups(tc.Claims)
|
||||||
|
if tc.ExpectedGroups != nil {
|
||||||
|
g.Expect(groups).To(Equal(tc.ExpectedGroups))
|
||||||
|
} else {
|
||||||
|
g.Expect(groups).To(BeNil())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -22,6 +22,14 @@ var (
|
||||||
// code
|
// code
|
||||||
ErrMissingCode = errors.New("missing code")
|
ErrMissingCode = errors.New("missing code")
|
||||||
|
|
||||||
|
// ErrMissingIDToken is returned when an oidc.Token does not contain the
|
||||||
|
// extra `id_token` field for an IDToken.
|
||||||
|
ErrMissingIDToken = errors.New("missing id_token")
|
||||||
|
|
||||||
|
// ErrMissingOIDCVerifier is returned when a provider didn't set `Verifier`
|
||||||
|
// but an attempt to call `Verifier.Verify` was about to be made.
|
||||||
|
ErrMissingOIDCVerifier = errors.New("oidc verifier is not configured")
|
||||||
|
|
||||||
_ Provider = (*ProviderData)(nil)
|
_ Provider = (*ProviderData)(nil)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,13 @@
|
||||||
package providers
|
package providers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
|
"github.com/bitly/go-simplejson"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|
@ -55,3 +59,42 @@ func makeLoginURL(p *ProviderData, redirectURI, state string, extraParams url.Va
|
||||||
a.RawQuery = params.Encode()
|
a.RawQuery = params.Encode()
|
||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getIDToken extracts an IDToken stored in the `Extra` fields of an
|
||||||
|
// oauth2.Token
|
||||||
|
func getIDToken(token *oauth2.Token) string {
|
||||||
|
idToken, ok := token.Extra("id_token").(string)
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return idToken
|
||||||
|
}
|
||||||
|
|
||||||
|
// formatGroup coerces an OIDC groups claim into a string
|
||||||
|
// If it is non-string, marshal it into JSON.
|
||||||
|
func formatGroup(rawGroup interface{}) (string, error) {
|
||||||
|
if group, ok := rawGroup.(string); ok {
|
||||||
|
return group, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonGroup, err := json.Marshal(rawGroup)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(jsonGroup), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// coerceArray extracts a field from simplejson.Json that might be a
|
||||||
|
// singleton or a list and coerces it into a list.
|
||||||
|
func coerceArray(sj *simplejson.Json, key string) []interface{} {
|
||||||
|
array, err := sj.Get(key).Array()
|
||||||
|
if err == nil {
|
||||||
|
return array
|
||||||
|
}
|
||||||
|
|
||||||
|
single := sj.Get(key).Interface()
|
||||||
|
if single == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return []interface{}{single}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,9 +5,10 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMakeAuhtorizationHeader(t *testing.T) {
|
func Test_makeAuthorizationHeader(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
prefix string
|
prefix string
|
||||||
|
|
@ -64,3 +65,49 @@ func TestMakeAuhtorizationHeader(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_getIDToken(t *testing.T) {
|
||||||
|
const idToken = "eyJfoobar.eyJfoobar.12345asdf"
|
||||||
|
g := NewWithT(t)
|
||||||
|
|
||||||
|
token := &oauth2.Token{}
|
||||||
|
g.Expect(getIDToken(token)).To(Equal(""))
|
||||||
|
|
||||||
|
extraToken := token.WithExtra(map[string]interface{}{
|
||||||
|
"id_token": idToken,
|
||||||
|
})
|
||||||
|
g.Expect(getIDToken(extraToken)).To(Equal(idToken))
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_formatGroup(t *testing.T) {
|
||||||
|
testCases := map[string]struct {
|
||||||
|
rawGroup interface{}
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
"String Group": {
|
||||||
|
rawGroup: "group",
|
||||||
|
expected: "group",
|
||||||
|
},
|
||||||
|
"Numeric Group": {
|
||||||
|
rawGroup: 123,
|
||||||
|
expected: "123",
|
||||||
|
},
|
||||||
|
"Map Group": {
|
||||||
|
rawGroup: map[string]string{"id": "1", "name": "Test"},
|
||||||
|
expected: "{\"id\":\"1\",\"name\":\"Test\"}",
|
||||||
|
},
|
||||||
|
"List Group": {
|
||||||
|
rawGroup: []string{"First", "Second"},
|
||||||
|
expected: "[\"First\",\"Second\"]",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for testName, tc := range testCases {
|
||||||
|
t.Run(testName, func(t *testing.T) {
|
||||||
|
g := NewWithT(t)
|
||||||
|
formattedGroup, err := formatGroup(tc.rawGroup)
|
||||||
|
g.Expect(err).ToNot(HaveOccurred())
|
||||||
|
g.Expect(formattedGroup).To(Equal(tc.expected))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue