Merge remote-tracking branch 'origin2/master'
This commit is contained in:
commit
e7919f0535
35
CHANGELOG.md
35
CHANGELOG.md
|
|
@ -4,7 +4,21 @@
|
|||
|
||||
## Important Notes
|
||||
|
||||
- [#964](https://github.com/oauth2-proxy/oauth2-proxy/pull/964) Redirect URL generation will attempt secondary strategies
|
||||
in the priority chain if any fail the `IsValidRedirect` security check. Previously any failures fell back to `/`.
|
||||
- [#953](https://github.com/oauth2-proxy/oauth2-proxy/pull/953) Keycloak will now use `--profile-url` if set for the userinfo endpoint
|
||||
instead of `--validate-url`. `--validate-url` will still work for backwards compatibility.
|
||||
- [#957](https://github.com/oauth2-proxy/oauth2-proxy/pull/957) To use X-Forwarded-{Proto,Host,Uri} on redirect detection, `--reverse-proxy` must be `true`.
|
||||
- [#936](https://github.com/oauth2-proxy/oauth2-proxy/pull/936) `--user-id-claim` option is deprecated and replaced by `--oidc-email-claim`
|
||||
- [#630](https://github.com/oauth2-proxy/oauth2-proxy/pull/630) Gitlab projects needs a Gitlab application with the extra `read_api` enabled
|
||||
- [#849](https://github.com/oauth2-proxy/oauth2-proxy/pull/849) `/oauth2/auth` `allowed_groups` querystring parameter can be paired with the `allowed-groups` configuration option.
|
||||
- The `allowed_groups` querystring parameter can specify multiple comma delimited groups.
|
||||
- In this scenario, the user must have a group (from their multiple groups) present in both lists to not get a 401 or 403 response code.
|
||||
- Example:
|
||||
- OAuth2-Proxy globally sets the `allowed_groups` as `engineering`.
|
||||
- An application using Kubernetes ingress uses the `/oauth2/auth` endpoint with `allowed_groups` querystring set to `backend`.
|
||||
- A user must have a session with the groups `["engineering", "backend"]` to pass authorization.
|
||||
- Another user with the groups `["engineering", "frontend"]` would fail the querystring authorization portion.
|
||||
- [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Existing sessions from v6.0.0 or earlier are no longer valid. They will trigger a reauthentication.
|
||||
- [#826](https://github.com/oauth2-proxy/oauth2-proxy/pull/826) `skip-auth-strip-headers` now applies to all requests, not just those where authentication would be skipped.
|
||||
- [#797](https://github.com/oauth2-proxy/oauth2-proxy/pull/797) The behavior of the Google provider Groups restriction changes with this
|
||||
|
|
@ -18,11 +32,19 @@
|
|||
- [#575](https://github.com/oauth2-proxy/oauth2-proxy/pull/575) Sessions from v5.1.1 or earlier will no longer validate since they were not signed with SHA1.
|
||||
- Sessions from v6.0.0 or later had a graceful conversion to SHA256 that resulted in no reauthentication
|
||||
- Upgrading from v5.1.1 or earlier will result in a reauthentication
|
||||
- [#616](https://github.com/oauth2-proxy/oauth2-proxy/pull/616) Ensure you have configured oauth2-proxy to use the `groups` scope. The user may be logged out initially as they may not currently have the `groups` claim however after going back through login process wil be authenticated.
|
||||
- [#616](https://github.com/oauth2-proxy/oauth2-proxy/pull/616) Ensure you have configured oauth2-proxy to use the `groups` scope.
|
||||
- The user may be logged out initially as they may not currently have the `groups` claim however after going back through login process wil be authenticated.
|
||||
- [#839](https://github.com/oauth2-proxy/oauth2-proxy/pull/839) Enables complex data structures for group claim entries, which are output as Json by default.
|
||||
|
||||
## Breaking Changes
|
||||
|
||||
- [#964](https://github.com/oauth2-proxy/oauth2-proxy/pull/964) `--reverse-proxy` must be true to trust `X-Forwarded-*` headers as canonical.
|
||||
These are used throughout the application in redirect URLs, cookie domains and host logging logic. These are the headers:
|
||||
- `X-Forwarded-Proto` instead of `req.URL.Scheme`
|
||||
- `X-Forwarded-Host` instead of `req.Host`
|
||||
- `X-Forwarded-Uri` instead of `req.URL.RequestURI()`
|
||||
- [#953](https://github.com/oauth2-proxy/oauth2-proxy/pull/953) In config files & envvar configs, `keycloak_group` is now the plural `keycloak_groups`.
|
||||
Flag configs are still `--keycloak-group` but it can be passed multiple times.
|
||||
- [#911](https://github.com/oauth2-proxy/oauth2-proxy/pull/911) Specifying a non-existent provider will cause OAuth2-Proxy to fail on startup instead of defaulting to "google".
|
||||
- [#797](https://github.com/oauth2-proxy/oauth2-proxy/pull/797) Security changes to Google provider group authorization flow
|
||||
- If you change the list of allowed groups, existing sessions that now don't have a valid group will be logged out immediately.
|
||||
|
|
@ -44,15 +66,23 @@
|
|||
|
||||
## Changes since v6.1.1
|
||||
|
||||
- [#995](https://github.com/oauth2-proxy/oauth2-proxy/pull/995) Add Security Policy (@JoelSpeed)
|
||||
- [#964](https://github.com/oauth2-proxy/oauth2-proxy/pull/964) Require `--reverse-proxy` true to trust `X-Forwareded-*` type headers (@NickMeves)
|
||||
- [#970](https://github.com/oauth2-proxy/oauth2-proxy/pull/970) Fix joined cookie name for those containing underline in the suffix (@peppered)
|
||||
- [#953](https://github.com/oauth2-proxy/oauth2-proxy/pull/953) Migrate Keycloak to EnrichSession & support multiple groups for authorization (@NickMeves)
|
||||
- [#957](https://github.com/oauth2-proxy/oauth2-proxy/pull/957) Use X-Forwarded-{Proto,Host,Uri} on redirect as last resort (@linuxgemini)
|
||||
- [#630](https://github.com/oauth2-proxy/oauth2-proxy/pull/630) Add support for Gitlab project based authentication (@factorysh)
|
||||
- [#907](https://github.com/oauth2-proxy/oauth2-proxy/pull/907) Introduce alpha configuration option to enable testing of structured configuration (@JoelSpeed)
|
||||
- [#938](https://github.com/oauth2-proxy/oauth2-proxy/pull/938) Cleanup missed provider renaming refactor methods (@NickMeves)
|
||||
- [#816](https://github.com/oauth2-proxy/oauth2-proxy/pull/816) (via [#936](https://github.com/oauth2-proxy/oauth2-proxy/pull/936)) Support non-list group claims (@loafoe)
|
||||
- [#936](https://github.com/oauth2-proxy/oauth2-proxy/pull/936) Refactor OIDC Provider and support groups from Profile URL (@NickMeves)
|
||||
- [#869](https://github.com/oauth2-proxy/oauth2-proxy/pull/869) Streamline provider interface method names and signatures (@NickMeves)
|
||||
- [#849](https://github.com/oauth2-proxy/oauth2-proxy/pull/849) Support group authorization on `oauth2/auth` endpoint via `allowed_groups` querystring (@NickMeves)
|
||||
- [#925](https://github.com/oauth2-proxy/oauth2-proxy/pull/925) Fix basic auth legacy header conversion (@JoelSpeed)
|
||||
- [#916](https://github.com/oauth2-proxy/oauth2-proxy/pull/916) Add AlphaOptions struct to prepare for alpha config loading (@JoelSpeed)
|
||||
- [#923](https://github.com/oauth2-proxy/oauth2-proxy/pull/923) Support TLS 1.3 (@aajisaka)
|
||||
- [#918](https://github.com/oauth2-proxy/oauth2-proxy/pull/918) Fix log header output (@JoelSpeed)
|
||||
- [#911](https://github.com/oauth2-proxy/oauth2-proxy/pull/911) Validate provider type on startup.
|
||||
- [#869](https://github.com/oauth2-proxy/oauth2-proxy/pull/869) Streamline provider interface method names and signatures (@NickMeves)
|
||||
- [#906](https://github.com/oauth2-proxy/oauth2-proxy/pull/906) Set up v6.1.x versioned documentation as default documentation (@JoelSpeed)
|
||||
- [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Remove v5 legacy sessions support (@NickMeves)
|
||||
- [#904](https://github.com/oauth2-proxy/oauth2-proxy/pull/904) Set `skip-auth-strip-headers` to `true` by default (@NickMeves)
|
||||
|
|
@ -79,6 +109,7 @@
|
|||
- [#750](https://github.com/oauth2-proxy/oauth2-proxy/pull/750) ci: Migrate to Github Actions (@shinebayar-g)
|
||||
- [#829](https://github.com/oauth2-proxy/oauth2-proxy/pull/820) Rename test directory to testdata (@johejo)
|
||||
- [#819](https://github.com/oauth2-proxy/oauth2-proxy/pull/819) Improve CI (@johejo)
|
||||
- [#989](https://github.com/oauth2-proxy/oauth2-proxy/pull/989) Adapt isAjax to support mimetype lists (@rassie)
|
||||
|
||||
# v6.1.1
|
||||
|
||||
|
|
|
|||
|
|
@ -1,2 +1,3 @@
|
|||
Joel Speed <joel.speed@hotmail.co.uk> (@JoelSpeed)
|
||||
Henry Jenkins <henry@henryjenkins.name> (@steakunderscore)
|
||||
Nick Meves <nick.meves@greenhouse.io> (@NickMeves)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
# Security Disclosures
|
||||
|
||||
Please see [our community docs](https://oauth2-proxy.github.io/oauth2-proxy/docs/community/security) for our security policy.
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
---
|
||||
id: security
|
||||
title: Security
|
||||
---
|
||||
|
||||
:::note
|
||||
OAuth2 Proxy is a community project.
|
||||
Maintainers do not work on this project full time, and as such,
|
||||
while we endeavour to respond to disclosures as quickly as possible,
|
||||
this may take longer than in projects with corporate sponsorship.
|
||||
:::
|
||||
|
||||
## Security Disclosures
|
||||
|
||||
:::important
|
||||
If you believe you have found a vulnerability within OAuth2 Proxy or any of its
|
||||
dependencies, please do NOT open an issue or PR on GitHub, please do NOT post
|
||||
any details publicly.
|
||||
:::
|
||||
|
||||
Security disclosures MUST be done in private.
|
||||
If you have found an issue that you would like to bring to the attention of the
|
||||
maintenance team for OAuth2 Proxy, please compose an email and send it to the
|
||||
list of maintainers in our [MAINTAINERS](https://github.com/oauth2-proxy/oauth2-proxy/blob/master/MAINTAINERS) file.
|
||||
|
||||
Please include as much detail as possible.
|
||||
Ideally, your disclosure should include:
|
||||
- A reproducible case that can be used to demonstrate the exploit
|
||||
- How you discovered this vulnerability
|
||||
- A potential fix for the issue (if you have thought of one)
|
||||
- Versions affected (if not present in master)
|
||||
- Your GitHub ID
|
||||
|
||||
### How will we respond to disclosures?
|
||||
|
||||
We use [GitHub Security Advisories](https://docs.github.com/en/github/managing-security-vulnerabilities/about-github-security-advisories)
|
||||
to privately discuss fixes for disclosed vulnerabilities.
|
||||
If you include a GitHub ID with your disclosure we will add you as a collaborator
|
||||
for the advisory so that you can join the discussion and validate any fixes
|
||||
we may propose.
|
||||
|
||||
For minor issues and previously disclosed vulnerabilities (typically for
|
||||
dependencies), we may use regular PRs for fixes and forego the security advisory.
|
||||
|
||||
Once a fix has been agreed upon, we will merge the fix and create a new release.
|
||||
If we have multiple security issues in flight simultaneously, we may delay
|
||||
merging fixes until all patches are ready.
|
||||
We may also backport the fix to previous releases,
|
||||
but this will be at the discretion of the maintainers.
|
||||
|
|
@ -135,15 +135,25 @@ If you are using GitHub enterprise, make sure you set the following to the appro
|
|||
|
||||
Make sure you set the following to the appropriate url:
|
||||
|
||||
-provider=keycloak
|
||||
-client-id=<client you have created>
|
||||
-client-secret=<your client's secret>
|
||||
-login-url="http(s)://<keycloak host>/auth/realms/<your realm>/protocol/openid-connect/auth"
|
||||
-redeem-url="http(s)://<keycloak host>/auth/realms/<your realm>/protocol/openid-connect/token"
|
||||
-validate-url="http(s)://<keycloak host>/auth/realms/<your realm>/protocol/openid-connect/userinfo"
|
||||
-keycloak-group=<user_group>
|
||||
--provider=keycloak
|
||||
--client-id=<client you have created>
|
||||
--client-secret=<your client's secret>
|
||||
--login-url="http(s)://<keycloak host>/auth/realms/<your realm>/protocol/openid-connect/auth"
|
||||
--redeem-url="http(s)://<keycloak host>/auth/realms/<your realm>/protocol/openid-connect/token"
|
||||
--profile-url="http(s)://<keycloak host>/auth/realms/<your realm>/protocol/openid-connect/userinfo"
|
||||
--validate-url="http(s)://<keycloak host>/auth/realms/<your realm>/protocol/openid-connect/userinfo"
|
||||
--keycloak-group=<first_allowed_user_group>
|
||||
--keycloak-group=<second_allowed_user_group>
|
||||
|
||||
The group management in keycloak is using a tree. If you create a group named admin in keycloak you should define the 'keycloak-group' value to /admin.
|
||||
For group based authorization, the optional `--keycloak-group` (legacy) or `--allowed-group` (global standard)
|
||||
flags can be used to specify which groups to limit access to.
|
||||
|
||||
If these are unset but a `groups` mapper is set up above in step (3), the provider will still
|
||||
populate the `X-Forwarded-Groups` header to your upstream server with the `groups` data in the
|
||||
Keycloak userinfo endpoint response.
|
||||
|
||||
The group management in keycloak is using a tree. If you create a group named admin in keycloak
|
||||
you should define the 'keycloak-group' value to /admin.
|
||||
|
||||
### GitLab Auth Provider
|
||||
|
||||
|
|
|
|||
|
|
@ -74,7 +74,8 @@ An example [oauth2-proxy.cfg](https://github.com/oauth2-proxy/oauth2-proxy/blob/
|
|||
| `--insecure-oidc-skip-issuer-verification` | bool | allow the OIDC issuer URL to differ from the expected (currently required for Azure multi-tenant compatibility) | false |
|
||||
| `--oidc-issuer-url` | string | the OpenID Connect issuer URL, e.g. `"https://accounts.google.com"` | |
|
||||
| `--oidc-jwks-url` | string | OIDC JWKS URI for token verification; required if OIDC discovery is disabled | |
|
||||
| `--oidc-groups-claim` | string | which claim contains the user groups | `"groups"` |
|
||||
| `--oidc-email-claim` | string | which OIDC claim contains the user's email | `"email"` |
|
||||
| `--oidc-groups-claim` | string | which OIDC claim contains the user groups | `"groups"` |
|
||||
| `--pass-access-token` | bool | pass OAuth access_token to upstream via X-Forwarded-Access-Token header. When used with `--set-xauthrequest` this adds the X-Auth-Request-Access-Token header to the response | false |
|
||||
| `--pass-authorization-header` | bool | pass OIDC IDToken to upstream via Authorization Bearer header | false |
|
||||
| `--pass-basic-auth` | bool | pass HTTP Basic Auth, X-Forwarded-User, X-Forwarded-Email and X-Forwarded-Preferred-Username information to upstream | true |
|
||||
|
|
@ -105,7 +106,7 @@ An example [oauth2-proxy.cfg](https://github.com/oauth2-proxy/oauth2-proxy/blob/
|
|||
| `--request-logging` | bool | Log requests | true |
|
||||
| `--request-logging-format` | string | Template for request log lines | see [Logging Configuration](#logging-configuration) |
|
||||
| `--resource` | string | The resource that is protected (Azure AD only) | |
|
||||
| `--reverse-proxy` | bool | are we running behind a reverse proxy, controls whether headers like X-Real-IP are accepted | false |
|
||||
| `--reverse-proxy` | bool | are we running behind a reverse proxy, controls whether headers like X-Real-IP are accepted and allows X-Forwarded-{Proto,Host,Uri} headers to be used on redirect selection | false |
|
||||
| `--scope` | string | OAuth scope specification | |
|
||||
| `--session-cookie-minimal` | bool | strip OAuth tokens from cookie session stores if they aren't needed (cookie session store only) | false |
|
||||
| `--session-store-type` | string | [Session data storage backend](sessions.md); redis or cookie | cookie |
|
||||
|
|
@ -128,7 +129,6 @@ An example [oauth2-proxy.cfg](https://github.com/oauth2-proxy/oauth2-proxy/blob/
|
|||
| `--tls-cert-file` | string | path to certificate file | |
|
||||
| `--tls-key-file` | string | path to private key file | |
|
||||
| `--upstream` | string \| list | the http url(s) of the upstream endpoint, file:// paths for static files or `static://<status_code>` for static response. Routing is based on the path | |
|
||||
| `--user-id-claim` | string | which claim contains the user ID | \["email"\] |
|
||||
| `--allowed-group` | string \| list | restrict logins to members of this group (may be given multiple times) | |
|
||||
| `--validate-url` | string | Access token validation endpoint | |
|
||||
| `--version` | n/a | print version string | |
|
||||
|
|
@ -354,6 +354,73 @@ It is recommended to use `--session-store-type=redis` when expecting large sessi
|
|||
|
||||
You have to substitute *name* with the actual cookie name you configured via --cookie-name parameter. If you don't set a custom cookie name the variable should be "$upstream_cookie__oauth2_proxy_1" instead of "$upstream_cookie_name_1" and the new cookie-name should be "_oauth2_proxy_1=" instead of "name_1=".
|
||||
|
||||
## Configuring for use with the Traefik (v2) `ForwardAuth` middleware
|
||||
|
||||
**This option requires `--reverse-proxy` option to be set.**
|
||||
|
||||
The [Traefik v2 `ForwardAuth` middleware](https://doc.traefik.io/traefik/middlewares/forwardauth/) allows Traefik to authenticate requests via the oauth2-proxy's `/oauth2/auth` endpoint on every request, which only returns a 202 Accepted response or a 401 Unauthorized response without proxying the whole request through. For example, on Dynamic File (YAML) Configuration:
|
||||
|
||||
```yaml
|
||||
http:
|
||||
routers:
|
||||
a-service:
|
||||
rule: "Host(`a-service.example.com`)"
|
||||
service: a-service-backend
|
||||
middlewares:
|
||||
- oauth-errors
|
||||
- oauth-auth
|
||||
tls:
|
||||
certResolver: default
|
||||
domains:
|
||||
- main: "example.com"
|
||||
sans:
|
||||
- "*.example.com"
|
||||
oauth:
|
||||
rule: "Host(`a-service.example.com`, `oauth.example.com`) && PathPrefix(`/oauth2/`)"
|
||||
middlewares:
|
||||
- auth-headers
|
||||
service: oauth-backend
|
||||
tls:
|
||||
certResolver: default
|
||||
domains:
|
||||
- main: "example.com"
|
||||
sans:
|
||||
- "*.example.com"
|
||||
|
||||
services:
|
||||
a-service-backend:
|
||||
loadBalancer:
|
||||
servers:
|
||||
- url: http://172.16.0.2:7555
|
||||
oauth-backend:
|
||||
loadBalancer:
|
||||
servers:
|
||||
- url: http://172.16.0.1:4180
|
||||
|
||||
middlewares:
|
||||
auth-headers:
|
||||
headers:
|
||||
sslRedirect: true
|
||||
stsSeconds: 315360000
|
||||
browserXssFilter: true
|
||||
contentTypeNosniff: true
|
||||
forceSTSHeader: true
|
||||
sslHost: example.com
|
||||
stsIncludeSubdomains: true
|
||||
stsPreload: true
|
||||
frameDeny: true
|
||||
oauth-auth:
|
||||
forwardAuth:
|
||||
address: https://oauth.example.com/oauth2/auth
|
||||
trustForwardHeader: true
|
||||
oauth-errors:
|
||||
errors:
|
||||
status:
|
||||
- "401-403"
|
||||
service: oauth-backend
|
||||
query: "/oauth2/sign_in"
|
||||
```
|
||||
|
||||
:::note
|
||||
If you set up your OAuth2 provider to rotate your client secret, you can use the `client-secret-file` option to reload the secret when it is updated.
|
||||
:::
|
||||
|
|
|
|||
|
|
@ -20,5 +20,11 @@ module.exports = {
|
|||
collapsed: false,
|
||||
items: ['features/endpoints', 'features/request_signatures'],
|
||||
},
|
||||
{
|
||||
type: 'category',
|
||||
label: 'Community',
|
||||
collapsed: false,
|
||||
items: ['community/security'],
|
||||
},
|
||||
],
|
||||
};
|
||||
|
|
|
|||
|
|
@ -0,0 +1,49 @@
|
|||
---
|
||||
id: security
|
||||
title: Security
|
||||
---
|
||||
|
||||
:::note
|
||||
OAuth2 Proxy is a community project.
|
||||
Maintainers do not work on this project full time, and as such,
|
||||
while we endeavour to respond to disclosures as quickly as possible,
|
||||
this may take longer than in projects with corporate sponsorship.
|
||||
:::
|
||||
|
||||
## Security Disclosures
|
||||
|
||||
:::important
|
||||
If you believe you have found a vulnerability within OAuth2 Proxy or any of its
|
||||
dependencies, please do NOT open an issue or PR on GitHub, please do NOT post any
|
||||
details publicly.
|
||||
:::
|
||||
|
||||
Security disclosures MUST be done in private.
|
||||
If you have found an issue that you would like to bring to the attention of the
|
||||
maintenance team for OAuth2 Proxy, please compose an email and send it to the
|
||||
list of maintainers in our [MAINTAINERS](https://github.com/oauth2-proxy/oauth2-proxy/blob/master/MAINTAINERS) file.
|
||||
|
||||
Please include as much detail as possible.
|
||||
Ideally, your disclosure should include:
|
||||
- A reproducible case that can be used to demonstrate the exploit
|
||||
- How you discovered this vulnerability
|
||||
- A potential fix for the issue (if you have thought of one)
|
||||
- Versions affected (if not present in master)
|
||||
- Your GitHub ID
|
||||
|
||||
### How will we respond to disclosures?
|
||||
|
||||
We use [GitHub Security Advisories](https://docs.github.com/en/github/managing-security-vulnerabilities/about-github-security-advisories)
|
||||
to privately discuss fixes for disclosed vulnerabilities.
|
||||
If you include a GitHub ID with your disclosure we will add you as a collaborator
|
||||
for the advisory so that you can join the discussion and validate any fixes
|
||||
we may propose.
|
||||
|
||||
For minor issues and previously disclosed vulnerabilities (typically for
|
||||
dependencies), we may use regular PRs for fixes and forego the security advisory.
|
||||
|
||||
Once a fix has been agreed upon, we will merge the fix and create a new release.
|
||||
If we have multiple security issues in flight simultaneously, we may delay
|
||||
merging fixes until all patches are ready.
|
||||
We may also backport the fix to previous releases,
|
||||
but this will be at the discretion of the maintainers.
|
||||
|
|
@ -45,6 +45,17 @@
|
|||
"id": "version-6.1.x/features/request_signatures"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"collapsed": false,
|
||||
"type": "category",
|
||||
"label": "Community",
|
||||
"items": [
|
||||
{
|
||||
"type": "doc",
|
||||
"id": "version-6.1.x/community/security"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
|
|||
663
oauthproxy.go
663
oauthproxy.go
|
|
@ -24,16 +24,14 @@ import (
|
|||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
|
||||
requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/upstream"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/providers"
|
||||
)
|
||||
|
||||
const (
|
||||
httpScheme = "http"
|
||||
httpsScheme = "https"
|
||||
|
||||
schemeHTTPS = "https"
|
||||
applicationJSON = "application/json"
|
||||
)
|
||||
|
||||
|
|
@ -229,7 +227,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
|
|||
// the OAuth2 Proxy authentication logic kicks in.
|
||||
// For example forcing HTTPS or health checks.
|
||||
func buildPreAuthChain(opts *options.Options) (alice.Chain, error) {
|
||||
chain := alice.New(middleware.NewScope())
|
||||
chain := alice.New(middleware.NewScope(opts.ReverseProxy))
|
||||
|
||||
if opts.ForceHTTPS {
|
||||
_, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress)
|
||||
|
|
@ -366,49 +364,6 @@ func buildRoutesAllowlist(opts *options.Options) ([]allowedRoute, error) {
|
|||
return routes, nil
|
||||
}
|
||||
|
||||
// GetRedirectURI returns the redirectURL that the upstream OAuth Provider will
|
||||
// redirect clients to once authenticated
|
||||
func (p *OAuthProxy) GetRedirectURI(host string) string {
|
||||
// default to the request Host if not set
|
||||
if p.redirectURL.Host != "" {
|
||||
return p.redirectURL.String()
|
||||
}
|
||||
u := *p.redirectURL
|
||||
if u.Scheme == "" {
|
||||
if p.CookieSecure {
|
||||
u.Scheme = httpsScheme
|
||||
} else {
|
||||
u.Scheme = httpScheme
|
||||
}
|
||||
}
|
||||
u.Host = host
|
||||
return u.String()
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (*sessionsapi.SessionState, error) {
|
||||
if code == "" {
|
||||
return nil, providers.ErrMissingCode
|
||||
}
|
||||
redirectURI := p.GetRedirectURI(host)
|
||||
s, err := p.provider.Redeem(ctx, redirectURI, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) enrichSessionState(ctx context.Context, s *sessionsapi.SessionState) error {
|
||||
var err error
|
||||
if s.Email == "" {
|
||||
s.Email, err = p.provider.GetEmailAddress(ctx, s)
|
||||
if err != nil && !errors.Is(err, providers.ErrNotImplemented) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return p.provider.EnrichSession(ctx, s)
|
||||
}
|
||||
|
||||
// MakeCSRFCookie creates a cookie for CSRF
|
||||
func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||
return p.makeCookie(req, p.CSRFCookieName, value, expiration, now)
|
||||
|
|
@ -418,7 +373,7 @@ func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, ex
|
|||
cookieDomain := cookies.GetCookieDomain(req, p.CookieDomains)
|
||||
|
||||
if cookieDomain != "" {
|
||||
domain := util.GetRequestHost(req)
|
||||
domain := requestutil.GetRequestHost(req)
|
||||
if h, _, err := net.SplitHostPort(domain); err == nil {
|
||||
domain = h
|
||||
}
|
||||
|
|
@ -466,6 +421,81 @@ func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *s
|
|||
return p.sessionStore.Save(rw, req, s)
|
||||
}
|
||||
|
||||
// IsValidRedirect checks whether the redirect URL is whitelisted
|
||||
func (p *OAuthProxy) IsValidRedirect(redirect string) bool {
|
||||
switch {
|
||||
case redirect == "":
|
||||
// The user didn't specify a redirect, should fallback to `/`
|
||||
return false
|
||||
case strings.HasPrefix(redirect, "/") && !strings.HasPrefix(redirect, "//") && !invalidRedirectRegex.MatchString(redirect):
|
||||
return true
|
||||
case strings.HasPrefix(redirect, "http://") || strings.HasPrefix(redirect, "https://"):
|
||||
redirectURL, err := url.Parse(redirect)
|
||||
if err != nil {
|
||||
logger.Printf("Rejecting invalid redirect %q: scheme unsupported or missing", redirect)
|
||||
return false
|
||||
}
|
||||
redirectHostname := redirectURL.Hostname()
|
||||
|
||||
for _, domain := range p.whitelistDomains {
|
||||
domainHostname, domainPort := splitHostPort(strings.TrimLeft(domain, "."))
|
||||
if domainHostname == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if (redirectHostname == domainHostname) || (strings.HasPrefix(domain, ".") && strings.HasSuffix(redirectHostname, domainHostname)) {
|
||||
// the domain names match, now validate the ports
|
||||
// if the whitelisted domain's port is '*', allow all ports
|
||||
// if the whitelisted domain contains a specific port, only allow that port
|
||||
// if the whitelisted domain doesn't contain a port at all, only allow empty redirect ports ie http and https
|
||||
redirectPort := redirectURL.Port()
|
||||
if (domainPort == "*") ||
|
||||
(domainPort == redirectPort) ||
|
||||
(domainPort == "" && redirectPort == "") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.Printf("Rejecting invalid redirect %q: domain / port not in whitelist", redirect)
|
||||
return false
|
||||
default:
|
||||
logger.Printf("Rejecting invalid redirect %q: not an absolute or relative URL", redirect)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
p.preAuthChain.Then(http.HandlerFunc(p.serveHTTP)).ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.URL.Path != p.AuthOnlyPath && strings.HasPrefix(req.URL.Path, p.ProxyPrefix) {
|
||||
prepareNoCache(rw)
|
||||
}
|
||||
|
||||
switch path := req.URL.Path; {
|
||||
case path == p.RobotsPath:
|
||||
p.RobotsTxt(rw)
|
||||
case p.IsAllowedRequest(req):
|
||||
p.SkipAuthProxy(rw, req)
|
||||
case path == p.SignInPath:
|
||||
p.SignIn(rw, req)
|
||||
case path == p.SignOutPath:
|
||||
p.SignOut(rw, req)
|
||||
case path == p.OAuthStartPath:
|
||||
p.OAuthStart(rw, req)
|
||||
case path == p.OAuthCallbackPath:
|
||||
p.OAuthCallback(rw, req)
|
||||
case path == p.AuthOnlyPath:
|
||||
p.AuthOnly(rw, req)
|
||||
case path == p.UserInfoPath:
|
||||
p.UserInfo(rw, req)
|
||||
default:
|
||||
p.Proxy(rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
// RobotsTxt disallows scraping pages from the OAuthProxy
|
||||
func (p *OAuthProxy) RobotsTxt(rw http.ResponseWriter) {
|
||||
_, err := fmt.Fprintf(rw, "User-agent: *\nDisallow: /")
|
||||
|
|
@ -496,6 +526,42 @@ func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, m
|
|||
}
|
||||
}
|
||||
|
||||
// IsAllowedRequest is used to check if auth should be skipped for this request
|
||||
func (p *OAuthProxy) IsAllowedRequest(req *http.Request) bool {
|
||||
isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS"
|
||||
return isPreflightRequestAllowed || p.isAllowedRoute(req) || p.isTrustedIP(req)
|
||||
}
|
||||
|
||||
// IsAllowedRoute is used to check if the request method & path is allowed without auth
|
||||
func (p *OAuthProxy) isAllowedRoute(req *http.Request) bool {
|
||||
for _, route := range p.allowedRoutes {
|
||||
if (route.method == "" || req.Method == route.method) && route.pathRegex.MatchString(req.URL.Path) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isTrustedIP is used to check if a request comes from a trusted client IP address.
|
||||
func (p *OAuthProxy) isTrustedIP(req *http.Request) bool {
|
||||
if p.trustedIPs == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
remoteAddr, err := ip.GetClientIP(p.realClientIPParser, req)
|
||||
if err != nil {
|
||||
logger.Errorf("Error obtaining real IP for trusted IP list: %v", err)
|
||||
// Possibly spoofed X-Real-IP header
|
||||
return false
|
||||
}
|
||||
|
||||
if remoteAddr == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return p.trustedIPs.Has(remoteAddr)
|
||||
}
|
||||
|
||||
// SignInPage writes the sing in template to the response
|
||||
func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code int) {
|
||||
prepareNoCache(rw)
|
||||
|
|
@ -507,7 +573,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code
|
|||
}
|
||||
rw.WriteHeader(code)
|
||||
|
||||
redirectURL, err := p.GetRedirect(req)
|
||||
redirectURL, err := p.getAppRedirect(req)
|
||||
if err != nil {
|
||||
logger.Errorf("Error obtaining redirect: %v", err)
|
||||
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
|
||||
|
|
@ -566,195 +632,9 @@ func (p *OAuthProxy) ManualSignIn(req *http.Request) (string, bool) {
|
|||
return "", false
|
||||
}
|
||||
|
||||
// GetRedirect reads the query parameter to get the URL to redirect clients to
|
||||
// once authenticated with the OAuthProxy
|
||||
func (p *OAuthProxy) GetRedirect(req *http.Request) (redirect string, err error) {
|
||||
err = req.ParseForm()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
redirect = req.Header.Get("X-Auth-Request-Redirect")
|
||||
if req.Form.Get("rd") != "" {
|
||||
redirect = req.Form.Get("rd")
|
||||
}
|
||||
if !p.IsValidRedirect(redirect) {
|
||||
// Use RequestURI to preserve ?query
|
||||
redirect = req.URL.RequestURI()
|
||||
if strings.HasPrefix(redirect, p.ProxyPrefix) {
|
||||
redirect = "/"
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// splitHostPort separates host and port. If the port is not valid, it returns
|
||||
// the entire input as host, and it doesn't check the validity of the host.
|
||||
// Unlike net.SplitHostPort, but per RFC 3986, it requires ports to be numeric.
|
||||
// *** taken from net/url, modified validOptionalPort() to accept ":*"
|
||||
func splitHostPort(hostport string) (host, port string) {
|
||||
host = hostport
|
||||
|
||||
colon := strings.LastIndexByte(host, ':')
|
||||
if colon != -1 && validOptionalPort(host[colon:]) {
|
||||
host, port = host[:colon], host[colon+1:]
|
||||
}
|
||||
|
||||
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
|
||||
host = host[1 : len(host)-1]
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// validOptionalPort reports whether port is either an empty string
|
||||
// or matches /^:\d*$/
|
||||
// *** taken from net/url, modified to accept ":*"
|
||||
func validOptionalPort(port string) bool {
|
||||
if port == "" || port == ":*" {
|
||||
return true
|
||||
}
|
||||
if port[0] != ':' {
|
||||
return false
|
||||
}
|
||||
for _, b := range port[1:] {
|
||||
if b < '0' || b > '9' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// IsValidRedirect checks whether the redirect URL is whitelisted
|
||||
func (p *OAuthProxy) IsValidRedirect(redirect string) bool {
|
||||
switch {
|
||||
case redirect == "":
|
||||
// The user didn't specify a redirect, should fallback to `/`
|
||||
return false
|
||||
case strings.HasPrefix(redirect, "/") && !strings.HasPrefix(redirect, "//") && !invalidRedirectRegex.MatchString(redirect):
|
||||
return true
|
||||
case strings.HasPrefix(redirect, "http://") || strings.HasPrefix(redirect, "https://"):
|
||||
redirectURL, err := url.Parse(redirect)
|
||||
if err != nil {
|
||||
logger.Printf("Rejecting invalid redirect %q: scheme unsupported or missing", redirect)
|
||||
return false
|
||||
}
|
||||
redirectHostname := redirectURL.Hostname()
|
||||
|
||||
for _, domain := range p.whitelistDomains {
|
||||
domainHostname, domainPort := splitHostPort(strings.TrimLeft(domain, "."))
|
||||
if domainHostname == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if (redirectHostname == domainHostname) || (strings.HasPrefix(domain, ".") && strings.HasSuffix(redirectHostname, domainHostname)) {
|
||||
// the domain names match, now validate the ports
|
||||
// if the whitelisted domain's port is '*', allow all ports
|
||||
// if the whitelisted domain contains a specific port, only allow that port
|
||||
// if the whitelisted domain doesn't contain a port at all, only allow empty redirect ports ie http and https
|
||||
redirectPort := redirectURL.Port()
|
||||
if (domainPort == "*") ||
|
||||
(domainPort == redirectPort) ||
|
||||
(domainPort == "" && redirectPort == "") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.Printf("Rejecting invalid redirect %q: domain / port not in whitelist", redirect)
|
||||
return false
|
||||
default:
|
||||
logger.Printf("Rejecting invalid redirect %q: not an absolute or relative URL", redirect)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// IsAllowedRequest is used to check if auth should be skipped for this request
|
||||
func (p *OAuthProxy) IsAllowedRequest(req *http.Request) bool {
|
||||
isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS"
|
||||
return isPreflightRequestAllowed || p.isAllowedRoute(req) || p.IsTrustedIP(req)
|
||||
}
|
||||
|
||||
// IsAllowedRoute is used to check if the request method & path is allowed without auth
|
||||
func (p *OAuthProxy) isAllowedRoute(req *http.Request) bool {
|
||||
for _, route := range p.allowedRoutes {
|
||||
if (route.method == "" || req.Method == route.method) && route.pathRegex.MatchString(req.URL.Path) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en
|
||||
var noCacheHeaders = map[string]string{
|
||||
"Expires": time.Unix(0, 0).Format(time.RFC1123),
|
||||
"Cache-Control": "no-cache, no-store, must-revalidate, max-age=0",
|
||||
"X-Accel-Expires": "0", // https://www.nginx.com/resources/wiki/start/topics/examples/x-accel/
|
||||
}
|
||||
|
||||
// prepareNoCache prepares headers for preventing browser caching.
|
||||
func prepareNoCache(w http.ResponseWriter) {
|
||||
// Set NoCache headers
|
||||
for k, v := range noCacheHeaders {
|
||||
w.Header().Set(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
// IsTrustedIP is used to check if a request comes from a trusted client IP address.
|
||||
func (p *OAuthProxy) IsTrustedIP(req *http.Request) bool {
|
||||
if p.trustedIPs == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
remoteAddr, err := ip.GetClientIP(p.realClientIPParser, req)
|
||||
if err != nil {
|
||||
logger.Errorf("Error obtaining real IP for trusted IP list: %v", err)
|
||||
// Possibly spoofed X-Real-IP header
|
||||
return false
|
||||
}
|
||||
|
||||
if remoteAddr == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return p.trustedIPs.Has(remoteAddr)
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
p.preAuthChain.Then(http.HandlerFunc(p.serveHTTP)).ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.URL.Path != p.AuthOnlyPath && strings.HasPrefix(req.URL.Path, p.ProxyPrefix) {
|
||||
prepareNoCache(rw)
|
||||
}
|
||||
|
||||
switch path := req.URL.Path; {
|
||||
case path == p.RobotsPath:
|
||||
p.RobotsTxt(rw)
|
||||
case p.IsAllowedRequest(req):
|
||||
p.SkipAuthProxy(rw, req)
|
||||
case path == p.SignInPath:
|
||||
p.SignIn(rw, req)
|
||||
case path == p.SignOutPath:
|
||||
p.SignOut(rw, req)
|
||||
case path == p.OAuthStartPath:
|
||||
p.OAuthStart(rw, req)
|
||||
case path == p.OAuthCallbackPath:
|
||||
p.OAuthCallback(rw, req)
|
||||
case path == p.AuthOnlyPath:
|
||||
p.AuthenticateOnly(rw, req)
|
||||
case path == p.UserInfoPath:
|
||||
p.UserInfo(rw, req)
|
||||
default:
|
||||
p.Proxy(rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
// SignIn serves a page prompting users to sign in
|
||||
func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) {
|
||||
redirect, err := p.GetRedirect(req)
|
||||
redirect, err := p.getAppRedirect(req)
|
||||
if err != nil {
|
||||
logger.Errorf("Error obtaining redirect: %v", err)
|
||||
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
|
||||
|
|
@ -812,7 +692,7 @@ func (p *OAuthProxy) UserInfo(rw http.ResponseWriter, req *http.Request) {
|
|||
|
||||
// SignOut sends a response to clear the authentication cookie
|
||||
func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) {
|
||||
redirect, err := p.GetRedirect(req)
|
||||
redirect, err := p.getAppRedirect(req)
|
||||
if err != nil {
|
||||
logger.Errorf("Error obtaining redirect: %v", err)
|
||||
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
|
||||
|
|
@ -837,13 +717,13 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
|
|||
return
|
||||
}
|
||||
p.SetCSRFCookie(rw, req, nonce)
|
||||
redirect, err := p.GetRedirect(req)
|
||||
redirect, err := p.getAppRedirect(req)
|
||||
if err != nil {
|
||||
logger.Errorf("Error obtaining redirect: %v", err)
|
||||
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
|
||||
return
|
||||
}
|
||||
redirectURI := p.GetRedirectURI(util.GetRequestHost(req))
|
||||
redirectURI := p.getOAuthRedirectURI(req)
|
||||
http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), http.StatusFound)
|
||||
}
|
||||
|
||||
|
|
@ -866,7 +746,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
session, err := p.redeemCode(req.Context(), util.GetRequestHost(req), req.Form.Get("code"))
|
||||
session, err := p.redeemCode(req)
|
||||
if err != nil {
|
||||
logger.Errorf("Error redeeming code during OAuth2 callback: %v", err)
|
||||
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error")
|
||||
|
|
@ -925,16 +805,50 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
// AuthenticateOnly checks whether the user is currently logged in
|
||||
func (p *OAuthProxy) AuthenticateOnly(rw http.ResponseWriter, req *http.Request) {
|
||||
func (p *OAuthProxy) redeemCode(req *http.Request) (*sessionsapi.SessionState, error) {
|
||||
code := req.Form.Get("code")
|
||||
if code == "" {
|
||||
return nil, providers.ErrMissingCode
|
||||
}
|
||||
|
||||
redirectURI := p.getOAuthRedirectURI(req)
|
||||
s, err := p.provider.Redeem(req.Context(), redirectURI, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) enrichSessionState(ctx context.Context, s *sessionsapi.SessionState) error {
|
||||
var err error
|
||||
if s.Email == "" {
|
||||
s.Email, err = p.provider.GetEmailAddress(ctx, s)
|
||||
if err != nil && !errors.Is(err, providers.ErrNotImplemented) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return p.provider.EnrichSession(ctx, s)
|
||||
}
|
||||
|
||||
// AuthOnly checks whether the user is currently logged in (both authentication
|
||||
// and optional authorization).
|
||||
func (p *OAuthProxy) AuthOnly(rw http.ResponseWriter, req *http.Request) {
|
||||
session, err := p.getAuthenticatedSession(rw, req)
|
||||
if err != nil {
|
||||
http.Error(rw, "unauthorized request", http.StatusUnauthorized)
|
||||
http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Unauthorized cases need to return 403 to prevent infinite redirects with
|
||||
// subrequest architectures
|
||||
if !authOnlyAuthorize(req, session) {
|
||||
http.Error(rw, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// we are authenticated
|
||||
p.addHeadersForProxying(rw, req, session)
|
||||
p.addHeadersForProxying(rw, session)
|
||||
p.headersChain.Then(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.WriteHeader(http.StatusAccepted)
|
||||
})).ServeHTTP(rw, req)
|
||||
|
|
@ -952,13 +866,13 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
|
|||
switch err {
|
||||
case nil:
|
||||
// we are authenticated
|
||||
p.addHeadersForProxying(rw, req, session)
|
||||
p.addHeadersForProxying(rw, session)
|
||||
p.headersChain.Then(p.serveMux).ServeHTTP(rw, req)
|
||||
case ErrNeedsLogin:
|
||||
// we need to send the user to a login screen
|
||||
if isAjax(req) {
|
||||
// no point redirecting an AJAX request
|
||||
p.ErrorJSON(rw, http.StatusUnauthorized)
|
||||
p.errorJSON(rw, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -977,7 +891,195 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
|
|||
p.ErrorPage(rw, http.StatusInternalServerError,
|
||||
"Internal Error", "Internal Error")
|
||||
}
|
||||
}
|
||||
|
||||
// See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en
|
||||
var noCacheHeaders = map[string]string{
|
||||
"Expires": time.Unix(0, 0).Format(time.RFC1123),
|
||||
"Cache-Control": "no-cache, no-store, must-revalidate, max-age=0",
|
||||
"X-Accel-Expires": "0", // https://www.nginx.com/resources/wiki/start/topics/examples/x-accel/
|
||||
}
|
||||
|
||||
// prepareNoCache prepares headers for preventing browser caching.
|
||||
func prepareNoCache(w http.ResponseWriter) {
|
||||
// Set NoCache headers
|
||||
for k, v := range noCacheHeaders {
|
||||
w.Header().Set(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
// getOAuthRedirectURI returns the redirectURL that the upstream OAuth Provider will
|
||||
// redirect clients to once authenticated.
|
||||
// This is usually the OAuthProxy callback URL.
|
||||
func (p *OAuthProxy) getOAuthRedirectURI(req *http.Request) string {
|
||||
// if `p.redirectURL` already has a host, return it
|
||||
if p.redirectURL.Host != "" {
|
||||
return p.redirectURL.String()
|
||||
}
|
||||
|
||||
// Otherwise figure out the scheme + host from the request
|
||||
rd := *p.redirectURL
|
||||
rd.Host = requestutil.GetRequestHost(req)
|
||||
rd.Scheme = requestutil.GetRequestProto(req)
|
||||
|
||||
// If CookieSecure is true, return `https` no matter what
|
||||
// Not all reverse proxies set X-Forwarded-Proto
|
||||
if p.CookieSecure {
|
||||
rd.Scheme = schemeHTTPS
|
||||
}
|
||||
return rd.String()
|
||||
}
|
||||
|
||||
// getAppRedirect determines the full URL or URI path to redirect clients to
|
||||
// once authenticated with the OAuthProxy
|
||||
// Strategy priority (first legal result is used):
|
||||
// - `rd` querysting parameter
|
||||
// - `X-Auth-Request-Redirect` header
|
||||
// - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled)
|
||||
// - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*)
|
||||
// - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled)
|
||||
// - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*)
|
||||
// - `/`
|
||||
func (p *OAuthProxy) getAppRedirect(req *http.Request) (string, error) {
|
||||
err := req.ParseForm()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// These redirect getter functions are strategies ordered by priority
|
||||
// for figuring out the redirect URL.
|
||||
type redirectGetter func(req *http.Request) string
|
||||
for _, rdGetter := range []redirectGetter{
|
||||
p.getRdQuerystringRedirect,
|
||||
p.getXAuthRequestRedirect,
|
||||
p.getXForwardedHeadersRedirect,
|
||||
p.getURIRedirect,
|
||||
} {
|
||||
redirect := rdGetter(req)
|
||||
// Call `p.IsValidRedirect` again here a final time to be safe
|
||||
if redirect != "" && p.IsValidRedirect(redirect) {
|
||||
return redirect, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "/", nil
|
||||
}
|
||||
|
||||
func isForwardedRequest(req *http.Request) bool {
|
||||
return requestutil.IsProxied(req) &&
|
||||
req.Host != requestutil.GetRequestHost(req)
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) hasProxyPrefix(path string) bool {
|
||||
return strings.HasPrefix(path, fmt.Sprintf("%s/", p.ProxyPrefix))
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) validateRedirect(redirect string, errorFormat string) string {
|
||||
if p.IsValidRedirect(redirect) {
|
||||
return redirect
|
||||
}
|
||||
if redirect != "" {
|
||||
logger.Errorf(errorFormat, redirect)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getRdQuerystringRedirect handles this getAppRedirect strategy:
|
||||
// - `rd` querysting parameter
|
||||
func (p *OAuthProxy) getRdQuerystringRedirect(req *http.Request) string {
|
||||
return p.validateRedirect(
|
||||
req.Form.Get("rd"),
|
||||
"Invalid redirect provided in rd querystring parameter: %s",
|
||||
)
|
||||
}
|
||||
|
||||
// getXAuthRequestRedirect handles this getAppRedirect strategy:
|
||||
// - `X-Auth-Request-Redirect` Header
|
||||
func (p *OAuthProxy) getXAuthRequestRedirect(req *http.Request) string {
|
||||
return p.validateRedirect(
|
||||
req.Header.Get("X-Auth-Request-Redirect"),
|
||||
"Invalid redirect provided in X-Auth-Request-Redirect header: %s",
|
||||
)
|
||||
}
|
||||
|
||||
// getXForwardedHeadersRedirect handles these getAppRedirect strategies:
|
||||
// - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled)
|
||||
// - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*)
|
||||
func (p *OAuthProxy) getXForwardedHeadersRedirect(req *http.Request) string {
|
||||
if !isForwardedRequest(req) {
|
||||
return ""
|
||||
}
|
||||
|
||||
uri := requestutil.GetRequestURI(req)
|
||||
if p.hasProxyPrefix(uri) {
|
||||
uri = "/"
|
||||
}
|
||||
|
||||
redirect := fmt.Sprintf(
|
||||
"%s://%s%s",
|
||||
requestutil.GetRequestProto(req),
|
||||
requestutil.GetRequestHost(req),
|
||||
uri,
|
||||
)
|
||||
|
||||
return p.validateRedirect(redirect,
|
||||
"Invalid redirect generated from X-Forwarded-* headers: %s")
|
||||
}
|
||||
|
||||
// getURIRedirect handles these getAppRedirect strategies:
|
||||
// - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled)
|
||||
// - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*)
|
||||
// - `/`
|
||||
func (p *OAuthProxy) getURIRedirect(req *http.Request) string {
|
||||
redirect := p.validateRedirect(
|
||||
requestutil.GetRequestURI(req),
|
||||
"Invalid redirect generated from X-Forwarded-Uri header: %s",
|
||||
)
|
||||
if redirect == "" {
|
||||
redirect = req.URL.RequestURI()
|
||||
}
|
||||
|
||||
if p.hasProxyPrefix(redirect) {
|
||||
return "/"
|
||||
}
|
||||
return redirect
|
||||
}
|
||||
|
||||
// splitHostPort separates host and port. If the port is not valid, it returns
|
||||
// the entire input as host, and it doesn't check the validity of the host.
|
||||
// Unlike net.SplitHostPort, but per RFC 3986, it requires ports to be numeric.
|
||||
// *** taken from net/url, modified validOptionalPort() to accept ":*"
|
||||
func splitHostPort(hostport string) (host, port string) {
|
||||
host = hostport
|
||||
|
||||
colon := strings.LastIndexByte(host, ':')
|
||||
if colon != -1 && validOptionalPort(host[colon:]) {
|
||||
host, port = host[:colon], host[colon+1:]
|
||||
}
|
||||
|
||||
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
|
||||
host = host[1 : len(host)-1]
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// validOptionalPort reports whether port is either an empty string
|
||||
// or matches /^:\d*$/
|
||||
// *** taken from net/url, modified to accept ":*"
|
||||
func validOptionalPort(port string) bool {
|
||||
if port == "" || port == ":*" {
|
||||
return true
|
||||
}
|
||||
if port[0] != ':' {
|
||||
return false
|
||||
}
|
||||
for _, b := range port[1:] {
|
||||
if b < '0' || b > '9' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// getAuthenticatedSession checks whether a user is authenticated and returns a session object and nil error if so
|
||||
|
|
@ -989,7 +1091,7 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R
|
|||
var session *sessionsapi.SessionState
|
||||
|
||||
getSession := p.sessionChain.Then(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
session = middleware.GetRequestScope(req).Session
|
||||
session = middlewareapi.GetRequestScope(req).Session
|
||||
}))
|
||||
getSession.ServeHTTP(rw, req)
|
||||
|
||||
|
|
@ -1016,8 +1118,55 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R
|
|||
return session, nil
|
||||
}
|
||||
|
||||
// authOnlyAuthorize handles special authorization logic that is only done
|
||||
// on the AuthOnly endpoint for use with Nginx subrequest architectures.
|
||||
//
|
||||
// TODO (@NickMeves): This method is a placeholder to be extended but currently
|
||||
// fails the linter. Remove the nolint when functionality expands.
|
||||
//
|
||||
//nolint:S1008
|
||||
func authOnlyAuthorize(req *http.Request, s *sessionsapi.SessionState) bool {
|
||||
// Allow secondary group restrictions based on the `allowed_groups`
|
||||
// querystring parameter
|
||||
if !checkAllowedGroups(req, s) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func checkAllowedGroups(req *http.Request, s *sessionsapi.SessionState) bool {
|
||||
allowedGroups := extractAllowedGroups(req)
|
||||
if len(allowedGroups) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, group := range s.Groups {
|
||||
if _, ok := allowedGroups[group]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func extractAllowedGroups(req *http.Request) map[string]struct{} {
|
||||
groups := map[string]struct{}{}
|
||||
|
||||
query := req.URL.Query()
|
||||
for _, allowedGroups := range query["allowed_groups"] {
|
||||
for _, group := range strings.Split(allowedGroups, ",") {
|
||||
if group != "" {
|
||||
groups[group] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return groups
|
||||
}
|
||||
|
||||
// addHeadersForProxying adds the appropriate headers the request / response for proxying
|
||||
func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) {
|
||||
func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, session *sessionsapi.SessionState) {
|
||||
if session.Email == "" {
|
||||
rw.Header().Set("GAP-Auth", session.User)
|
||||
} else {
|
||||
|
|
@ -1029,16 +1178,24 @@ func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, req *http.Req
|
|||
func isAjax(req *http.Request) bool {
|
||||
acceptValues := req.Header.Values("Accept")
|
||||
const ajaxReq = applicationJSON
|
||||
for _, v := range acceptValues {
|
||||
if v == ajaxReq {
|
||||
// Iterate over multiple Accept headers, i.e.
|
||||
// Accept: application/json
|
||||
// Accept: text/plain
|
||||
for _, mimeTypes := range acceptValues {
|
||||
// Iterate over multiple mimetypes in a single header, i.e.
|
||||
// Accept: application/json, text/plain, */*
|
||||
for _, mimeType := range strings.Split(mimeTypes, ",") {
|
||||
mimeType = strings.TrimSpace(mimeType)
|
||||
if mimeType == ajaxReq {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ErrorJSON returns the error code with an application/json mime type
|
||||
func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) {
|
||||
// errorJSON returns the error code with an application/json mime type
|
||||
func (p *OAuthProxy) errorJSON(rw http.ResponseWriter, code int) {
|
||||
rw.Header().Set("Content-Type", applicationJSON)
|
||||
rw.WriteHeader(code)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ import (
|
|||
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/mbland/hmacauth"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
|
|
@ -414,8 +415,9 @@ func Test_redeemCode(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = proxy.redeemCode(context.Background(), "www.example.com", "")
|
||||
assert.Error(t, err)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
_, err = proxy.redeemCode(req)
|
||||
assert.Equal(t, providers.ErrMissingCode, err)
|
||||
}
|
||||
|
||||
func Test_enrichSession(t *testing.T) {
|
||||
|
|
@ -1197,18 +1199,20 @@ func TestUserInfoEndpointUnauthorizedOnNoCookieSetError(t *testing.T) {
|
|||
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
|
||||
}
|
||||
|
||||
func NewAuthOnlyEndpointTest(modifiers ...OptionsModifier) (*ProcessCookieTest, error) {
|
||||
func NewAuthOnlyEndpointTest(querystring string, modifiers ...OptionsModifier) (*ProcessCookieTest, error) {
|
||||
pcTest, err := NewProcessCookieTestWithOptionsModifiers(modifiers...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pcTest.req, _ = http.NewRequest("GET",
|
||||
pcTest.opts.ProxyPrefix+"/auth", nil)
|
||||
pcTest.req, _ = http.NewRequest(
|
||||
"GET",
|
||||
fmt.Sprintf("%s/auth%s", pcTest.opts.ProxyPrefix, querystring),
|
||||
nil)
|
||||
return pcTest, nil
|
||||
}
|
||||
|
||||
func TestAuthOnlyEndpointAccepted(t *testing.T) {
|
||||
test, err := NewAuthOnlyEndpointTest()
|
||||
test, err := NewAuthOnlyEndpointTest("")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
@ -1226,7 +1230,7 @@ func TestAuthOnlyEndpointAccepted(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) {
|
||||
test, err := NewAuthOnlyEndpointTest()
|
||||
test, err := NewAuthOnlyEndpointTest("")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
@ -1234,11 +1238,11 @@ func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) {
|
|||
test.proxy.ServeHTTP(test.rw, test.req)
|
||||
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
|
||||
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
|
||||
assert.Equal(t, "unauthorized request\n", string(bodyBytes))
|
||||
assert.Equal(t, "Unauthorized\n", string(bodyBytes))
|
||||
}
|
||||
|
||||
func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
|
||||
test, err := NewAuthOnlyEndpointTest(func(opts *options.Options) {
|
||||
test, err := NewAuthOnlyEndpointTest("", func(opts *options.Options) {
|
||||
opts.Cookie.Expire = time.Duration(24) * time.Hour
|
||||
})
|
||||
if err != nil {
|
||||
|
|
@ -1254,11 +1258,11 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
|
|||
test.proxy.ServeHTTP(test.rw, test.req)
|
||||
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
|
||||
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
|
||||
assert.Equal(t, "unauthorized request\n", string(bodyBytes))
|
||||
assert.Equal(t, "Unauthorized\n", string(bodyBytes))
|
||||
}
|
||||
|
||||
func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
|
||||
test, err := NewAuthOnlyEndpointTest()
|
||||
test, err := NewAuthOnlyEndpointTest("")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
@ -1273,7 +1277,7 @@ func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
|
|||
test.proxy.ServeHTTP(test.rw, test.req)
|
||||
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
|
||||
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
|
||||
assert.Equal(t, "unauthorized request\n", string(bodyBytes))
|
||||
assert.Equal(t, "Unauthorized\n", string(bodyBytes))
|
||||
}
|
||||
|
||||
func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
|
||||
|
|
@ -1746,8 +1750,9 @@ func TestRequestSignature(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestGetRedirect(t *testing.T) {
|
||||
func Test_getAppRedirect(t *testing.T) {
|
||||
opts := baseTestOptions()
|
||||
opts.WhitelistDomains = append(opts.WhitelistDomains, ".example.com", ".example.com:8443")
|
||||
err := validation.Validate(opts)
|
||||
assert.NoError(t, err)
|
||||
require.NotEmpty(t, opts.ProxyPrefix)
|
||||
|
|
@ -1759,28 +1764,144 @@ func TestGetRedirect(t *testing.T) {
|
|||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
headers map[string]string
|
||||
reverseProxy bool
|
||||
expectedRedirect string
|
||||
}{
|
||||
{
|
||||
name: "request outside of ProxyPrefix redirects to original URL",
|
||||
url: "/foo/bar",
|
||||
headers: nil,
|
||||
reverseProxy: false,
|
||||
expectedRedirect: "/foo/bar",
|
||||
},
|
||||
{
|
||||
name: "request with query preserves query",
|
||||
url: "/foo?bar",
|
||||
headers: nil,
|
||||
reverseProxy: false,
|
||||
expectedRedirect: "/foo?bar",
|
||||
},
|
||||
{
|
||||
name: "request under ProxyPrefix redirects to root",
|
||||
url: proxy.ProxyPrefix + "/foo/bar",
|
||||
headers: nil,
|
||||
reverseProxy: false,
|
||||
expectedRedirect: "/",
|
||||
},
|
||||
{
|
||||
name: "proxied request outside of ProxyPrefix redirects to proxied URL",
|
||||
url: "https://oauth.example.com/foo/bar",
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "https",
|
||||
"X-Forwarded-Host": "a-service.example.com",
|
||||
"X-Forwarded-Uri": "/foo/bar",
|
||||
},
|
||||
reverseProxy: true,
|
||||
expectedRedirect: "https://a-service.example.com/foo/bar",
|
||||
},
|
||||
{
|
||||
name: "non-proxied request with spoofed proxy headers wouldn't redirect",
|
||||
url: "https://oauth.example.com/foo?bar",
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "https",
|
||||
"X-Forwarded-Host": "a-service.example.com",
|
||||
"X-Forwarded-Uri": "/foo/bar",
|
||||
},
|
||||
reverseProxy: false,
|
||||
expectedRedirect: "/foo?bar",
|
||||
},
|
||||
{
|
||||
name: "proxied request under ProxyPrefix redirects to root",
|
||||
url: "https://oauth.example.com" + proxy.ProxyPrefix + "/foo/bar",
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "https",
|
||||
"X-Forwarded-Host": "a-service.example.com",
|
||||
"X-Forwarded-Uri": proxy.ProxyPrefix + "/foo/bar",
|
||||
},
|
||||
reverseProxy: true,
|
||||
expectedRedirect: "https://a-service.example.com/",
|
||||
},
|
||||
{
|
||||
name: "proxied request with port under ProxyPrefix redirects to root",
|
||||
url: "https://oauth.example.com" + proxy.ProxyPrefix + "/foo/bar",
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "https",
|
||||
"X-Forwarded-Host": "a-service.example.com:8443",
|
||||
"X-Forwarded-Uri": proxy.ProxyPrefix + "/foo/bar",
|
||||
},
|
||||
reverseProxy: true,
|
||||
expectedRedirect: "https://a-service.example.com:8443/",
|
||||
},
|
||||
{
|
||||
name: "proxied request with missing uri header would still redirect to desired redirect",
|
||||
url: "https://oauth.example.com/foo?bar",
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "https",
|
||||
"X-Forwarded-Host": "a-service.example.com",
|
||||
},
|
||||
reverseProxy: true,
|
||||
expectedRedirect: "https://a-service.example.com/foo?bar",
|
||||
},
|
||||
{
|
||||
name: "request with headers proxy not being set (and reverse proxy enabled) would still redirect to desired redirect",
|
||||
url: "https://oauth.example.com/foo?bar",
|
||||
headers: nil,
|
||||
reverseProxy: true,
|
||||
expectedRedirect: "/foo?bar",
|
||||
},
|
||||
{
|
||||
name: "proxied request with X-Auth-Request-Redirect being set outside of ProxyPrefix redirects to proxied URL",
|
||||
url: "https://oauth.example.com/foo/bar",
|
||||
headers: map[string]string{
|
||||
"X-Auth-Request-Redirect": "https://a-service.example.com/foo/bar",
|
||||
},
|
||||
reverseProxy: true,
|
||||
expectedRedirect: "https://a-service.example.com/foo/bar",
|
||||
},
|
||||
{
|
||||
name: "proxied request with rd query string redirects to proxied URL",
|
||||
url: "https://oauth.example.com/foo/bar?rd=https%3A%2F%2Fa%2Dservice%2Eexample%2Ecom%2Ffoo%2Fbar",
|
||||
headers: nil,
|
||||
reverseProxy: false,
|
||||
expectedRedirect: "https://a-service.example.com/foo/bar",
|
||||
},
|
||||
{
|
||||
name: "proxied request with rd query string and all headers set (and reverse proxy not enabled) redirects to proxied URL on rd query string",
|
||||
url: "https://oauth.example.com/foo/bar?rd=https%3A%2F%2Fa%2Dservice%2Eexample%2Ecom%2Ffoo%2Fjazz",
|
||||
headers: map[string]string{
|
||||
"X-Auth-Request-Redirect": "https://a-service.example.com/foo/baz",
|
||||
"X-Forwarded-Proto": "http",
|
||||
"X-Forwarded-Host": "another-service.example.com",
|
||||
"X-Forwarded-Uri": "/seasons/greetings",
|
||||
},
|
||||
reverseProxy: false,
|
||||
expectedRedirect: "https://a-service.example.com/foo/jazz",
|
||||
},
|
||||
{
|
||||
name: "proxied request with rd query string and some headers set redirects to proxied URL on rd query string",
|
||||
url: "https://oauth.example.com/foo/bar?rd=https%3A%2F%2Fa%2Dservice%2Eexample%2Ecom%2Ffoo%2Fbaz",
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "https",
|
||||
"X-Forwarded-Host": "another-service.example.com",
|
||||
"X-Forwarded-Uri": "/seasons/greetings",
|
||||
},
|
||||
reverseProxy: true,
|
||||
expectedRedirect: "https://a-service.example.com/foo/baz",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", tt.url, nil)
|
||||
redirect, err := proxy.GetRedirect(req)
|
||||
for header, value := range tt.headers {
|
||||
if value != "" {
|
||||
req.Header.Add(header, value)
|
||||
}
|
||||
}
|
||||
req = middleware.AddRequestScope(req, &middleware.RequestScope{
|
||||
ReverseProxy: tt.reverseProxy,
|
||||
})
|
||||
redirect, err := proxy.getAppRedirect(req)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedRedirect, redirect)
|
||||
|
|
@ -1848,6 +1969,13 @@ func TestAjaxUnauthorizedRequest2(t *testing.T) {
|
|||
testAjaxUnauthorizedRequest(t, header)
|
||||
}
|
||||
|
||||
func TestAjaxUnauthorizedRequestAccept1(t *testing.T) {
|
||||
header := make(http.Header)
|
||||
header.Add("Accept", "application/json, text/plain, */*")
|
||||
|
||||
testAjaxUnauthorizedRequest(t, header)
|
||||
}
|
||||
|
||||
func TestAjaxForbiddendRequest(t *testing.T) {
|
||||
test, err := newAjaxRequestTest()
|
||||
if err != nil {
|
||||
|
|
@ -1960,7 +2088,7 @@ func TestGetJwtSession(t *testing.T) {
|
|||
verifier := oidc.NewVerifier("https://issuer.example.com", keyset,
|
||||
&oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true})
|
||||
|
||||
test, err := NewAuthOnlyEndpointTest(func(opts *options.Options) {
|
||||
test, err := NewAuthOnlyEndpointTest("", func(opts *options.Options) {
|
||||
opts.InjectRequestHeaders = []options.Header{
|
||||
{
|
||||
Name: "Authorization",
|
||||
|
|
@ -2028,7 +2156,6 @@ func TestGetJwtSession(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}
|
||||
|
||||
opts.SkipJwtBearerTokens = true
|
||||
opts.SetJWTBearerVerifiers(append(opts.GetJWTBearerVerifiers(), verifier))
|
||||
})
|
||||
|
|
@ -2692,32 +2819,106 @@ func TestProxyAllowedGroups(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAuthOnlyAllowedGroups(t *testing.T) {
|
||||
tests := []struct {
|
||||
testCases := []struct {
|
||||
name string
|
||||
allowedGroups []string
|
||||
groups []string
|
||||
expectUnauthorized bool
|
||||
querystring string
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{"NoAllowedGroups", []string{}, []string{}, false},
|
||||
{"NoAllowedGroupsUserHasGroups", []string{}, []string{"a", "b"}, false},
|
||||
{"UserInAllowedGroup", []string{"a"}, []string{"a", "b"}, false},
|
||||
{"UserNotInAllowedGroup", []string{"a"}, []string{"c"}, true},
|
||||
{
|
||||
name: "NoAllowedGroups",
|
||||
allowedGroups: []string{},
|
||||
groups: []string{},
|
||||
querystring: "",
|
||||
expectedStatusCode: http.StatusAccepted,
|
||||
},
|
||||
{
|
||||
name: "NoAllowedGroupsUserHasGroups",
|
||||
allowedGroups: []string{},
|
||||
groups: []string{"a", "b"},
|
||||
querystring: "",
|
||||
expectedStatusCode: http.StatusAccepted,
|
||||
},
|
||||
{
|
||||
name: "UserInAllowedGroup",
|
||||
allowedGroups: []string{"a"},
|
||||
groups: []string{"a", "b"},
|
||||
querystring: "",
|
||||
expectedStatusCode: http.StatusAccepted,
|
||||
},
|
||||
{
|
||||
name: "UserNotInAllowedGroup",
|
||||
allowedGroups: []string{"a"},
|
||||
groups: []string{"c"},
|
||||
querystring: "",
|
||||
expectedStatusCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "UserInQuerystringGroup",
|
||||
allowedGroups: []string{"a", "b"},
|
||||
groups: []string{"a", "c"},
|
||||
querystring: "?allowed_groups=a",
|
||||
expectedStatusCode: http.StatusAccepted,
|
||||
},
|
||||
{
|
||||
name: "UserInMultiParamQuerystringGroup",
|
||||
allowedGroups: []string{"a", "b"},
|
||||
groups: []string{"b"},
|
||||
querystring: "?allowed_groups=a&allowed_groups=b,d",
|
||||
expectedStatusCode: http.StatusAccepted,
|
||||
},
|
||||
{
|
||||
name: "UserInOnlyQuerystringGroup",
|
||||
allowedGroups: []string{},
|
||||
groups: []string{"a", "c"},
|
||||
querystring: "?allowed_groups=a,b",
|
||||
expectedStatusCode: http.StatusAccepted,
|
||||
},
|
||||
{
|
||||
name: "UserInDelimitedQuerystringGroup",
|
||||
allowedGroups: []string{"a", "b", "c"},
|
||||
groups: []string{"c"},
|
||||
querystring: "?allowed_groups=a,c",
|
||||
expectedStatusCode: http.StatusAccepted,
|
||||
},
|
||||
{
|
||||
name: "UserNotInQuerystringGroup",
|
||||
allowedGroups: []string{},
|
||||
groups: []string{"c"},
|
||||
querystring: "?allowed_groups=a,b",
|
||||
expectedStatusCode: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "UserInConfigGroupNotInQuerystringGroup",
|
||||
allowedGroups: []string{"a", "b", "c"},
|
||||
groups: []string{"c"},
|
||||
querystring: "?allowed_groups=a,b",
|
||||
expectedStatusCode: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "UserInQuerystringGroupNotInConfigGroup",
|
||||
allowedGroups: []string{"a", "b"},
|
||||
groups: []string{"c"},
|
||||
querystring: "?allowed_groups=b,c",
|
||||
expectedStatusCode: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
emailAddress := "test"
|
||||
created := time.Now()
|
||||
|
||||
session := &sessions.SessionState{
|
||||
Groups: tt.groups,
|
||||
Groups: tc.groups,
|
||||
Email: emailAddress,
|
||||
AccessToken: "oauth_token",
|
||||
CreatedAt: &created,
|
||||
}
|
||||
|
||||
test, err := NewAuthOnlyEndpointTest(func(opts *options.Options) {
|
||||
opts.AllowedGroups = tt.allowedGroups
|
||||
test, err := NewAuthOnlyEndpointTest(tc.querystring, func(opts *options.Options) {
|
||||
opts.AllowedGroups = tc.allowedGroups
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
|
@ -2728,11 +2929,7 @@ func TestAuthOnlyAllowedGroups(t *testing.T) {
|
|||
|
||||
test.proxy.ServeHTTP(test.rw, test.req)
|
||||
|
||||
if tt.expectUnauthorized {
|
||||
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
|
||||
} else {
|
||||
assert.Equal(t, http.StatusAccepted, test.rw.Code)
|
||||
}
|
||||
assert.Equal(t, tc.expectedStatusCode, test.rw.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,19 @@
|
|||
package middleware_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// TestMiddlewareSuite and related tests are in a *_test package
|
||||
// to prevent circular imports with the `logger` package which uses
|
||||
// this functionality
|
||||
func TestMiddlewareSuite(t *testing.T) {
|
||||
logger.SetOutput(GinkgoWriter)
|
||||
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Middleware API")
|
||||
}
|
||||
|
|
@ -1,13 +1,26 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
)
|
||||
|
||||
type scopeKey string
|
||||
|
||||
// RequestScopeKey uses a typed string to reduce likelihood of clashing
|
||||
// with other context keys
|
||||
const RequestScopeKey scopeKey = "request-scope"
|
||||
|
||||
// RequestScope contains information regarding the request that is being made.
|
||||
// The RequestScope is used to pass information between different middlewares
|
||||
// within the chain.
|
||||
type RequestScope struct {
|
||||
// ReverseProxy tracks whether OAuth2-Proxy is operating in reverse proxy
|
||||
// mode and if request `X-Forwarded-*` headers should be trusted
|
||||
ReverseProxy bool
|
||||
|
||||
// Session details the authenticated users information (if it exists).
|
||||
Session *sessions.SessionState
|
||||
|
||||
|
|
@ -22,3 +35,19 @@ type RequestScope struct {
|
|||
// it was loaded or not.
|
||||
SessionRevalidated bool
|
||||
}
|
||||
|
||||
// GetRequestScope returns the current request scope from the given request
|
||||
func GetRequestScope(req *http.Request) *RequestScope {
|
||||
scope := req.Context().Value(RequestScopeKey)
|
||||
if scope == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return scope.(*RequestScope)
|
||||
}
|
||||
|
||||
// AddRequestScope adds a RequestScope to a request
|
||||
func AddRequestScope(req *http.Request, scope *RequestScope) *http.Request {
|
||||
ctx := context.WithValue(req.Context(), RequestScopeKey, scope)
|
||||
return req.WithContext(ctx)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,56 @@
|
|||
package middleware_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Scope Suite", func() {
|
||||
Context("GetRequestScope", func() {
|
||||
var request *http.Request
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
request, err = http.NewRequest("", "http://127.0.0.1/", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
Context("with a scope", func() {
|
||||
var scope *middleware.RequestScope
|
||||
|
||||
BeforeEach(func() {
|
||||
scope = &middleware.RequestScope{}
|
||||
request = middleware.AddRequestScope(request, scope)
|
||||
})
|
||||
|
||||
It("returns the scope", func() {
|
||||
s := middleware.GetRequestScope(request)
|
||||
Expect(s).ToNot(BeNil())
|
||||
Expect(s).To(Equal(scope))
|
||||
})
|
||||
|
||||
Context("if the scope is then modified", func() {
|
||||
BeforeEach(func() {
|
||||
Expect(scope.SaveSession).To(BeFalse())
|
||||
scope.SaveSession = true
|
||||
})
|
||||
|
||||
It("returns the updated session", func() {
|
||||
s := middleware.GetRequestScope(request)
|
||||
Expect(s).ToNot(BeNil())
|
||||
Expect(s).To(Equal(scope))
|
||||
Expect(s.SaveSession).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("without a scope", func() {
|
||||
It("returns nil", func() {
|
||||
Expect(middleware.GetRequestScope(request)).To(BeNil())
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
@ -36,7 +36,7 @@ type Options struct {
|
|||
TLSKeyFile string `flag:"tls-key-file" cfg:"tls_key_file"`
|
||||
|
||||
AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"`
|
||||
KeycloakGroup string `flag:"keycloak-group" cfg:"keycloak_group"`
|
||||
KeycloakGroups []string `flag:"keycloak-group" cfg:"keycloak_groups"`
|
||||
AzureTenant string `flag:"azure-tenant" cfg:"azure_tenant"`
|
||||
BitbucketTeam string `flag:"bitbucket-team" cfg:"bitbucket_team"`
|
||||
BitbucketRepository string `flag:"bitbucket-repository" cfg:"bitbucket_repository"`
|
||||
|
|
@ -87,6 +87,7 @@ type Options struct {
|
|||
InsecureOIDCSkipIssuerVerification bool `flag:"insecure-oidc-skip-issuer-verification" cfg:"insecure_oidc_skip_issuer_verification"`
|
||||
SkipOIDCDiscovery bool `flag:"skip-oidc-discovery" cfg:"skip_oidc_discovery"`
|
||||
OIDCJwksURL string `flag:"oidc-jwks-url" cfg:"oidc_jwks_url"`
|
||||
OIDCEmailClaim string `flag:"oidc-email-claim" cfg:"oidc_email_claim"`
|
||||
OIDCGroupsClaim string `flag:"oidc-groups-claim" cfg:"oidc_groups_claim"`
|
||||
LoginURL string `flag:"login-url" cfg:"login_url"`
|
||||
RedeemURL string `flag:"redeem-url" cfg:"redeem_url"`
|
||||
|
|
@ -148,11 +149,12 @@ func NewOptions() *Options {
|
|||
SkipAuthPreflight: false,
|
||||
Prompt: "", // Change to "login" when ApprovalPrompt officially deprecated
|
||||
ApprovalPrompt: "force",
|
||||
UserIDClaim: "email",
|
||||
InsecureOIDCAllowUnverifiedEmail: false,
|
||||
SkipOIDCDiscovery: false,
|
||||
Logging: loggingDefaults(),
|
||||
OIDCGroupsClaim: "groups",
|
||||
UserIDClaim: providers.OIDCEmailClaim, // Deprecated: Use OIDCEmailClaim
|
||||
OIDCEmailClaim: providers.OIDCEmailClaim,
|
||||
OIDCGroupsClaim: providers.OIDCGroupsClaim,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -179,7 +181,7 @@ func NewFlagSet() *pflag.FlagSet {
|
|||
|
||||
flagSet.StringSlice("email-domain", []string{}, "authenticate emails with the specified domain (may be given multiple times). Use * to authenticate any email")
|
||||
flagSet.StringSlice("whitelist-domain", []string{}, "allowed domains for redirection after authentication. Prefix domain with a . to allow subdomains (eg .example.com)")
|
||||
flagSet.String("keycloak-group", "", "restrict login to members of this group.")
|
||||
flagSet.StringSlice("keycloak-group", []string{}, "restrict logins to members of these groups (may be given multiple times)")
|
||||
flagSet.String("azure-tenant", "common", "go to a tenant-specific or common (tenant-independent) endpoint.")
|
||||
flagSet.String("bitbucket-team", "", "restrict logins to members of this team")
|
||||
flagSet.String("bitbucket-repository", "", "restrict logins to user with access to this repository")
|
||||
|
|
@ -226,7 +228,8 @@ func NewFlagSet() *pflag.FlagSet {
|
|||
flagSet.Bool("insecure-oidc-skip-issuer-verification", false, "Do not verify if issuer matches OIDC discovery URL")
|
||||
flagSet.Bool("skip-oidc-discovery", false, "Skip OIDC discovery and use manually supplied Endpoints")
|
||||
flagSet.String("oidc-jwks-url", "", "OpenID Connect JWKS URL (ie: https://www.googleapis.com/oauth2/v3/certs)")
|
||||
flagSet.String("oidc-groups-claim", "groups", "which claim contains the user groups")
|
||||
flagSet.String("oidc-groups-claim", providers.OIDCGroupsClaim, "which OIDC claim contains the user groups")
|
||||
flagSet.String("oidc-email-claim", providers.OIDCEmailClaim, "which OIDC claim contains the user's email")
|
||||
flagSet.String("login-url", "", "Authentication endpoint")
|
||||
flagSet.String("redeem-url", "", "Token redemption endpoint")
|
||||
flagSet.String("profile-url", "", "Profile access endpoint")
|
||||
|
|
@ -243,7 +246,7 @@ func NewFlagSet() *pflag.FlagSet {
|
|||
flagSet.String("pubjwk-url", "", "JWK pubkey access endpoint: required by login.gov")
|
||||
flagSet.Bool("gcp-healthchecks", false, "Enable GCP/GKE healthcheck endpoints")
|
||||
|
||||
flagSet.String("user-id-claim", "email", "which claim contains the user ID")
|
||||
flagSet.String("user-id-claim", providers.OIDCEmailClaim, "(DEPRECATED for `oidc-email-claim`) which claim contains the user ID")
|
||||
flagSet.StringSlice("allowed-group", []string{}, "restrict logins to members of this group (may be given multiple times)")
|
||||
|
||||
flagSet.AddFlagSet(cookieFlagSet())
|
||||
|
|
|
|||
|
|
@ -9,14 +9,14 @@ import (
|
|||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util"
|
||||
requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util"
|
||||
)
|
||||
|
||||
// MakeCookie constructs a cookie from the given parameters,
|
||||
// discovering the domain from the request if not specified.
|
||||
func MakeCookie(req *http.Request, name string, value string, path string, domain string, httpOnly bool, secure bool, expiration time.Duration, now time.Time, sameSite http.SameSite) *http.Cookie {
|
||||
if domain != "" {
|
||||
host := util.GetRequestHost(req)
|
||||
host := requestutil.GetRequestHost(req)
|
||||
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||
host = h
|
||||
}
|
||||
|
|
@ -48,7 +48,7 @@ func MakeCookieFromOptions(req *http.Request, name string, value string, cookieO
|
|||
// If nothing matches, create the cookie with the shortest domain
|
||||
defaultDomain := ""
|
||||
if len(cookieOpts.Domains) > 0 {
|
||||
logger.Errorf("Warning: request host %q did not match any of the specific cookie domains of %q", util.GetRequestHost(req), strings.Join(cookieOpts.Domains, ","))
|
||||
logger.Errorf("Warning: request host %q did not match any of the specific cookie domains of %q", requestutil.GetRequestHost(req), strings.Join(cookieOpts.Domains, ","))
|
||||
defaultDomain = cookieOpts.Domains[len(cookieOpts.Domains)-1]
|
||||
}
|
||||
return MakeCookie(req, name, value, cookieOpts.Path, defaultDomain, cookieOpts.HTTPOnly, cookieOpts.Secure, expiration, now, ParseSameSite(cookieOpts.SameSite))
|
||||
|
|
@ -57,7 +57,7 @@ func MakeCookieFromOptions(req *http.Request, name string, value string, cookieO
|
|||
// GetCookieDomain returns the correct cookie domain given a list of domains
|
||||
// by checking the X-Fowarded-Host and host header of an an http request
|
||||
func GetCookieDomain(req *http.Request, cookieDomains []string) string {
|
||||
host := util.GetRequestHost(req)
|
||||
host := requestutil.GetRequestHost(req)
|
||||
for _, domain := range cookieDomains {
|
||||
if strings.HasSuffix(host, domain) {
|
||||
return domain
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import (
|
|||
"text/template"
|
||||
"time"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util"
|
||||
requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util"
|
||||
)
|
||||
|
||||
// AuthStatus defines the different types of auth logging that occur
|
||||
|
|
@ -197,7 +197,7 @@ func (l *Logger) PrintAuthf(username string, req *http.Request, status AuthStatu
|
|||
|
||||
err := l.authTemplate.Execute(l.writer, authLogMessageData{
|
||||
Client: client,
|
||||
Host: util.GetRequestHost(req),
|
||||
Host: requestutil.GetRequestHost(req),
|
||||
Protocol: req.Proto,
|
||||
RequestMethod: req.Method,
|
||||
Timestamp: FormatTimestamp(now),
|
||||
|
|
@ -251,7 +251,7 @@ func (l *Logger) PrintReq(username, upstream string, req *http.Request, url url.
|
|||
|
||||
err := l.reqTemplate.Execute(l.writer, reqLogMessageData{
|
||||
Client: client,
|
||||
Host: util.GetRequestHost(req),
|
||||
Host: requestutil.GetRequestHost(req),
|
||||
Protocol: req.Proto,
|
||||
RequestDuration: fmt.Sprintf("%0.3f", duration),
|
||||
RequestMethod: req.Method,
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/justinas/alice"
|
||||
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
|
|
@ -23,7 +24,7 @@ func NewBasicAuthSessionLoader(validator basic.Validator) alice.Constructor {
|
|||
// If a session was loaded by a previous handler, it will not be replaced.
|
||||
func loadBasicAuthSession(validator basic.Validator, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
scope := GetRequestScope(req)
|
||||
scope := middlewareapi.GetRequestScope(req)
|
||||
// If scope is nil, this will panic.
|
||||
// A scope should always be injected before this handler is called.
|
||||
if scope.Session != nil {
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
|
@ -40,8 +39,7 @@ var _ = Describe("Basic Auth Session Suite", func() {
|
|||
// Set up the request with the authorization header and a request scope
|
||||
req := httptest.NewRequest("", "/", nil)
|
||||
req.Header.Set("Authorization", in.authorizationHeader)
|
||||
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope)
|
||||
req = req.WithContext(contextWithScope)
|
||||
req = middlewareapi.AddRequestScope(req, scope)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
|
|
@ -57,7 +55,7 @@ var _ = Describe("Basic Auth Session Suite", func() {
|
|||
// from the scope
|
||||
var gotSession *sessionsapi.SessionState
|
||||
handler := NewBasicAuthSessionLoader(validator)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session
|
||||
gotSession = middlewareapi.GetRequestScope(r).Session
|
||||
}))
|
||||
handler.ServeHTTP(rw, req)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/justinas/alice"
|
||||
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/header"
|
||||
)
|
||||
|
|
@ -61,7 +62,7 @@ func newRequestHeaderInjector(headers []options.Header) (alice.Constructor, erro
|
|||
|
||||
func injectRequestHeaders(injector header.Injector, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
scope := GetRequestScope(req)
|
||||
scope := middlewareapi.GetRequestScope(req)
|
||||
|
||||
// If scope is nil, this will panic.
|
||||
// A scope should always be injected before this handler is called.
|
||||
|
|
@ -92,7 +93,7 @@ func newResponseHeaderInjector(headers []options.Header) (alice.Constructor, err
|
|||
|
||||
func injectResponseHeaders(injector header.Injector, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
scope := GetRequestScope(req)
|
||||
scope := middlewareapi.GetRequestScope(req)
|
||||
|
||||
// If scope is nil, this will panic.
|
||||
// A scope should always be injected before this handler is called.
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
|
@ -31,8 +30,7 @@ var _ = Describe("Headers Suite", func() {
|
|||
|
||||
// Set up the request with a request scope
|
||||
req := httptest.NewRequest("", "/", nil)
|
||||
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope)
|
||||
req = req.WithContext(contextWithScope)
|
||||
req = middlewareapi.AddRequestScope(req, scope)
|
||||
req.Header = in.initialHeaders.Clone()
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
|
@ -218,8 +216,7 @@ var _ = Describe("Headers Suite", func() {
|
|||
|
||||
// Set up the request with a request scope
|
||||
req := httptest.NewRequest("", "/", nil)
|
||||
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope)
|
||||
req = req.WithContext(contextWithScope)
|
||||
req = middlewareapi.AddRequestScope(req, scope)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
for key, values := range in.initialHeaders {
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ type jwtSessionLoader struct {
|
|||
// If a session was loaded by a previous handler, it will not be replaced.
|
||||
func (j *jwtSessionLoader) loadSession(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
scope := GetRequestScope(req)
|
||||
scope := middlewareapi.GetRequestScope(req)
|
||||
// If scope is nil, this will panic.
|
||||
// A scope should always be injected before this handler is called.
|
||||
if scope.Session != nil {
|
||||
|
|
|
|||
|
|
@ -103,8 +103,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
|
|||
// Set up the request with the authorization header and a request scope
|
||||
req := httptest.NewRequest("", "/", nil)
|
||||
req.Header.Set("Authorization", in.authorizationHeader)
|
||||
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope)
|
||||
req = req.WithContext(contextWithScope)
|
||||
req = middlewareapi.AddRequestScope(req, scope)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
|
|
@ -116,7 +115,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
|
|||
// from the scope
|
||||
var gotSession *sessionsapi.SessionState
|
||||
handler := NewJwtSessionLoader(sessionLoaders)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session
|
||||
gotSession = middlewareapi.GetRequestScope(r).Session
|
||||
}))
|
||||
handler.ServeHTTP(rw, req)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/justinas/alice"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util"
|
||||
requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util"
|
||||
)
|
||||
|
||||
const httpsScheme = "https"
|
||||
|
|
@ -26,10 +26,11 @@ func NewRedirectToHTTPS(httpsPort string) alice.Constructor {
|
|||
// to the port from the httpsAddress given.
|
||||
func redirectToHTTPS(httpsPort string, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
proto := req.Header.Get("X-Forwarded-Proto")
|
||||
if strings.EqualFold(proto, httpsScheme) || (req.TLS != nil && proto == "") {
|
||||
// Only care about the connection to us being HTTPS if the proto is empty,
|
||||
// otherwise the proto is source of truth
|
||||
proto := requestutil.GetRequestProto(req)
|
||||
if strings.EqualFold(proto, httpsScheme) || (req.TLS != nil && proto == req.URL.Scheme) {
|
||||
// Only care about the connection to us being HTTPS if the proto wasn't
|
||||
// from a trusted `X-Forwarded-Proto` (proto == req.URL.Scheme).
|
||||
// Otherwise the proto is source of truth
|
||||
next.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
|
|
@ -41,7 +42,7 @@ func redirectToHTTPS(httpsPort string, next http.Handler) http.Handler {
|
|||
|
||||
// Set the Host in case the targetURL still does not have one
|
||||
// or it isn't X-Forwarded-Host aware
|
||||
targetURL.Host = util.GetRequestHost(req)
|
||||
targetURL.Host = requestutil.GetRequestHost(req)
|
||||
|
||||
// Overwrite the port if the original request was to a non-standard port
|
||||
if targetURL.Port() != "" {
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import (
|
|||
"fmt"
|
||||
"net/http/httptest"
|
||||
|
||||
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/ginkgo/extensions/table"
|
||||
. "github.com/onsi/gomega"
|
||||
|
|
@ -21,6 +22,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
|||
requestString string
|
||||
useTLS bool
|
||||
headers map[string]string
|
||||
reverseProxy bool
|
||||
expectedStatus int
|
||||
expectedBody string
|
||||
expectedLocation string
|
||||
|
|
@ -35,6 +37,10 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
|||
if in.useTLS {
|
||||
req.TLS = &tls.ConnectionState{}
|
||||
}
|
||||
scope := &middlewareapi.RequestScope{
|
||||
ReverseProxy: in.reverseProxy,
|
||||
}
|
||||
req = middlewareapi.AddRequestScope(req, scope)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
|
|
@ -52,6 +58,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
|||
requestString: "http://example.com",
|
||||
useTLS: false,
|
||||
headers: map[string]string{},
|
||||
reverseProxy: false,
|
||||
expectedStatus: 308,
|
||||
expectedBody: permanentRedirectBody("https://example.com"),
|
||||
expectedLocation: "https://example.com",
|
||||
|
|
@ -60,6 +67,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
|||
requestString: "https://example.com",
|
||||
useTLS: true,
|
||||
headers: map[string]string{},
|
||||
reverseProxy: false,
|
||||
expectedStatus: 200,
|
||||
expectedBody: "test",
|
||||
}),
|
||||
|
|
@ -69,15 +77,28 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
|||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "HTTPS",
|
||||
},
|
||||
reverseProxy: true,
|
||||
expectedStatus: 200,
|
||||
expectedBody: "test",
|
||||
}),
|
||||
Entry("without TLS and X-Forwarded-Proto=HTTPS but ReverseProxy not set", &requestTableInput{
|
||||
requestString: "http://example.com",
|
||||
useTLS: false,
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "HTTPS",
|
||||
},
|
||||
reverseProxy: false,
|
||||
expectedStatus: 308,
|
||||
expectedBody: permanentRedirectBody("https://example.com"),
|
||||
expectedLocation: "https://example.com",
|
||||
}),
|
||||
Entry("with TLS and X-Forwarded-Proto=HTTPS", &requestTableInput{
|
||||
requestString: "https://example.com",
|
||||
useTLS: true,
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "HTTPS",
|
||||
},
|
||||
reverseProxy: true,
|
||||
expectedStatus: 200,
|
||||
expectedBody: "test",
|
||||
}),
|
||||
|
|
@ -87,6 +108,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
|||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "https",
|
||||
},
|
||||
reverseProxy: true,
|
||||
expectedStatus: 200,
|
||||
expectedBody: "test",
|
||||
}),
|
||||
|
|
@ -96,6 +118,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
|||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "https",
|
||||
},
|
||||
reverseProxy: true,
|
||||
expectedStatus: 200,
|
||||
expectedBody: "test",
|
||||
}),
|
||||
|
|
@ -105,6 +128,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
|||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "HTTP",
|
||||
},
|
||||
reverseProxy: true,
|
||||
expectedStatus: 308,
|
||||
expectedBody: permanentRedirectBody("https://example.com"),
|
||||
expectedLocation: "https://example.com",
|
||||
|
|
@ -115,6 +139,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
|||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "HTTP",
|
||||
},
|
||||
reverseProxy: true,
|
||||
expectedStatus: 308,
|
||||
expectedBody: permanentRedirectBody("https://example.com"),
|
||||
expectedLocation: "https://example.com",
|
||||
|
|
@ -125,6 +150,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
|||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "http",
|
||||
},
|
||||
reverseProxy: true,
|
||||
expectedStatus: 308,
|
||||
expectedBody: permanentRedirectBody("https://example.com"),
|
||||
expectedLocation: "https://example.com",
|
||||
|
|
@ -135,6 +161,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
|||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "http",
|
||||
},
|
||||
reverseProxy: true,
|
||||
expectedStatus: 308,
|
||||
expectedBody: permanentRedirectBody("https://example.com"),
|
||||
expectedLocation: "https://example.com",
|
||||
|
|
@ -143,6 +170,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
|||
requestString: "http://example.com:8080",
|
||||
useTLS: false,
|
||||
headers: map[string]string{},
|
||||
reverseProxy: false,
|
||||
expectedStatus: 308,
|
||||
expectedBody: permanentRedirectBody("https://example.com:8443"),
|
||||
expectedLocation: "https://example.com:8443",
|
||||
|
|
@ -151,6 +179,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
|||
requestString: "https://example.com:8443",
|
||||
useTLS: true,
|
||||
headers: map[string]string{},
|
||||
reverseProxy: false,
|
||||
expectedStatus: 200,
|
||||
expectedBody: "test",
|
||||
}),
|
||||
|
|
@ -161,6 +190,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
|||
requestString: "/",
|
||||
useTLS: false,
|
||||
expectedStatus: 308,
|
||||
reverseProxy: false,
|
||||
expectedBody: permanentRedirectBody("https://example.com/"),
|
||||
expectedLocation: "https://example.com/",
|
||||
}),
|
||||
|
|
@ -171,6 +201,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
|
|||
"X-Forwarded-Proto": "HTTP",
|
||||
"X-Forwarded-Host": "external.example.com",
|
||||
},
|
||||
reverseProxy: true,
|
||||
expectedStatus: 308,
|
||||
expectedBody: permanentRedirectBody("https://external.example.com"),
|
||||
expectedLocation: "https://external.example.com",
|
||||
|
|
|
|||
|
|
@ -1,39 +1,20 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/justinas/alice"
|
||||
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||
)
|
||||
|
||||
type scopeKey string
|
||||
|
||||
// requestScopeKey uses a typed string to reduce likelihood of clasing
|
||||
// with other context keys
|
||||
const requestScopeKey scopeKey = "request-scope"
|
||||
|
||||
func NewScope() alice.Constructor {
|
||||
return addScope
|
||||
}
|
||||
|
||||
// addScope injects a new request scope into the request context.
|
||||
func addScope(next http.Handler) http.Handler {
|
||||
func NewScope(reverseProxy bool) alice.Constructor {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
scope := &middlewareapi.RequestScope{}
|
||||
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope)
|
||||
requestWithScope := req.WithContext(contextWithScope)
|
||||
next.ServeHTTP(rw, requestWithScope)
|
||||
scope := &middlewareapi.RequestScope{
|
||||
ReverseProxy: reverseProxy,
|
||||
}
|
||||
req = middlewareapi.AddRequestScope(req, scope)
|
||||
next.ServeHTTP(rw, req)
|
||||
})
|
||||
}
|
||||
|
||||
// GetRequestScope returns the current request scope from the given request
|
||||
func GetRequestScope(req *http.Request) *middlewareapi.RequestScope {
|
||||
scope := req.Context().Value(requestScopeKey)
|
||||
if scope == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return scope.(*middlewareapi.RequestScope)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
|
|
@ -21,8 +20,11 @@ var _ = Describe("Scope Suite", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
rw = httptest.NewRecorder()
|
||||
})
|
||||
|
||||
handler := NewScope()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Context("ReverseProxy is false", func() {
|
||||
BeforeEach(func() {
|
||||
handler := NewScope(false)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextRequest = r
|
||||
w.WriteHeader(200)
|
||||
}))
|
||||
|
|
@ -30,64 +32,37 @@ var _ = Describe("Scope Suite", func() {
|
|||
})
|
||||
|
||||
It("does not add a scope to the original request", func() {
|
||||
Expect(request.Context().Value(requestScopeKey)).To(BeNil())
|
||||
Expect(request.Context().Value(middlewareapi.RequestScopeKey)).To(BeNil())
|
||||
})
|
||||
|
||||
It("cannot load a scope from the original request using GetRequestScope", func() {
|
||||
Expect(GetRequestScope(request)).To(BeNil())
|
||||
Expect(middlewareapi.GetRequestScope(request)).To(BeNil())
|
||||
})
|
||||
|
||||
It("adds a scope to the request for the next handler", func() {
|
||||
Expect(nextRequest.Context().Value(requestScopeKey)).ToNot(BeNil())
|
||||
Expect(nextRequest.Context().Value(middlewareapi.RequestScopeKey)).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("can load a scope from the next handler's request using GetRequestScope", func() {
|
||||
Expect(GetRequestScope(nextRequest)).ToNot(BeNil())
|
||||
scope := middlewareapi.GetRequestScope(nextRequest)
|
||||
Expect(scope).ToNot(BeNil())
|
||||
Expect(scope.ReverseProxy).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Context("GetRequestScope", func() {
|
||||
var request *http.Request
|
||||
|
||||
Context("ReverseProxy is true", func() {
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
request, err = http.NewRequest("", "http://127.0.0.1/", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
handler := NewScope(true)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextRequest = r
|
||||
w.WriteHeader(200)
|
||||
}))
|
||||
handler.ServeHTTP(rw, request)
|
||||
})
|
||||
|
||||
Context("with a scope", func() {
|
||||
var scope *middlewareapi.RequestScope
|
||||
|
||||
BeforeEach(func() {
|
||||
scope = &middlewareapi.RequestScope{}
|
||||
contextWithScope := context.WithValue(request.Context(), requestScopeKey, scope)
|
||||
request = request.WithContext(contextWithScope)
|
||||
})
|
||||
|
||||
It("returns the scope", func() {
|
||||
s := GetRequestScope(request)
|
||||
Expect(s).ToNot(BeNil())
|
||||
Expect(s).To(Equal(scope))
|
||||
})
|
||||
|
||||
Context("if the scope is then modified", func() {
|
||||
BeforeEach(func() {
|
||||
Expect(scope.SaveSession).To(BeFalse())
|
||||
scope.SaveSession = true
|
||||
})
|
||||
|
||||
It("returns the updated session", func() {
|
||||
s := GetRequestScope(request)
|
||||
Expect(s).ToNot(BeNil())
|
||||
Expect(s).To(Equal(scope))
|
||||
Expect(s.SaveSession).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("without a scope", func() {
|
||||
It("returns nil", func() {
|
||||
Expect(GetRequestScope(request)).To(BeNil())
|
||||
It("return a scope where the ReverseProxy field is true", func() {
|
||||
scope := middlewareapi.GetRequestScope(nextRequest)
|
||||
Expect(scope).ToNot(BeNil())
|
||||
Expect(scope.ReverseProxy).To(BeTrue())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/justinas/alice"
|
||||
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
)
|
||||
|
|
@ -59,7 +60,7 @@ type storedSessionLoader struct {
|
|||
// If a session was loader by a previous handler, it will not be replaced.
|
||||
func (s *storedSessionLoader) loadSession(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
scope := GetRequestScope(req)
|
||||
scope := middlewareapi.GetRequestScope(req)
|
||||
// If scope is nil, this will panic.
|
||||
// A scope should always be injected before this handler is called.
|
||||
if scope.Session != nil {
|
||||
|
|
|
|||
|
|
@ -104,8 +104,7 @@ var _ = Describe("Stored Session Suite", func() {
|
|||
// Set up the request with the request headesr and a request scope
|
||||
req := httptest.NewRequest("", "/", nil)
|
||||
req.Header = in.requestHeaders
|
||||
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope)
|
||||
req = req.WithContext(contextWithScope)
|
||||
req = middlewareapi.AddRequestScope(req, scope)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
|
|
@ -120,7 +119,7 @@ var _ = Describe("Stored Session Suite", func() {
|
|||
// from the scope
|
||||
var gotSession *sessionsapi.SessionState
|
||||
handler := NewStoredSessionLoader(opts)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session
|
||||
gotSession = middlewareapi.GetRequestScope(r).Session
|
||||
}))
|
||||
handler.ServeHTTP(rw, req)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,48 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||
)
|
||||
|
||||
// GetRequestProto returns the request scheme or X-Forwarded-Proto if present
|
||||
// and the request is proxied.
|
||||
func GetRequestProto(req *http.Request) string {
|
||||
proto := req.Header.Get("X-Forwarded-Proto")
|
||||
if !IsProxied(req) || proto == "" {
|
||||
proto = req.URL.Scheme
|
||||
}
|
||||
return proto
|
||||
}
|
||||
|
||||
// GetRequestHost returns the request host header or X-Forwarded-Host if
|
||||
// present and the request is proxied.
|
||||
func GetRequestHost(req *http.Request) string {
|
||||
host := req.Header.Get("X-Forwarded-Host")
|
||||
if !IsProxied(req) || host == "" {
|
||||
host = req.Host
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
// GetRequestURI return the request URI or X-Forwarded-Uri if present and the
|
||||
// request is proxied.
|
||||
func GetRequestURI(req *http.Request) string {
|
||||
uri := req.Header.Get("X-Forwarded-Uri")
|
||||
if !IsProxied(req) || uri == "" {
|
||||
// Use RequestURI to preserve ?query
|
||||
uri = req.URL.RequestURI()
|
||||
}
|
||||
return uri
|
||||
}
|
||||
|
||||
// IsProxied determines if a request was from a proxy based on the RequestScope
|
||||
// ReverseProxy tracker.
|
||||
func IsProxied(req *http.Request) bool {
|
||||
scope := middlewareapi.GetRequestScope(req)
|
||||
if scope == nil {
|
||||
return false
|
||||
}
|
||||
return scope.ReverseProxy
|
||||
}
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
package util_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// TestRequestUtilSuite and related tests are in a *_test package
|
||||
// to prevent circular imports with the `logger` package which uses
|
||||
// this functionality
|
||||
func TestRequestUtilSuite(t *testing.T) {
|
||||
logger.SetOutput(GinkgoWriter)
|
||||
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Request Utils")
|
||||
}
|
||||
|
|
@ -0,0 +1,131 @@
|
|||
package util_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Util Suite", func() {
|
||||
const (
|
||||
proto = "http"
|
||||
host = "www.oauth2proxy.test"
|
||||
uri = "/test/endpoint"
|
||||
)
|
||||
var req *http.Request
|
||||
|
||||
BeforeEach(func() {
|
||||
req = httptest.NewRequest(
|
||||
http.MethodGet,
|
||||
fmt.Sprintf("%s://%s%s", proto, host, uri),
|
||||
nil,
|
||||
)
|
||||
})
|
||||
|
||||
Context("GetRequestHost", func() {
|
||||
Context("IsProxied is false", func() {
|
||||
BeforeEach(func() {
|
||||
req = middleware.AddRequestScope(req, &middleware.RequestScope{})
|
||||
})
|
||||
|
||||
It("returns the host", func() {
|
||||
Expect(util.GetRequestHost(req)).To(Equal(host))
|
||||
})
|
||||
|
||||
It("ignores X-Forwarded-Host and returns the host", func() {
|
||||
req.Header.Add("X-Forwarded-Host", "external.oauth2proxy.text")
|
||||
Expect(util.GetRequestHost(req)).To(Equal(host))
|
||||
})
|
||||
})
|
||||
|
||||
Context("IsProxied is true", func() {
|
||||
BeforeEach(func() {
|
||||
req = middleware.AddRequestScope(req, &middleware.RequestScope{
|
||||
ReverseProxy: true,
|
||||
})
|
||||
})
|
||||
|
||||
It("returns the host if X-Forwarded-Host is not present", func() {
|
||||
Expect(util.GetRequestHost(req)).To(Equal(host))
|
||||
})
|
||||
|
||||
It("returns the X-Forwarded-Host when present", func() {
|
||||
req.Header.Add("X-Forwarded-Host", "external.oauth2proxy.text")
|
||||
Expect(util.GetRequestHost(req)).To(Equal("external.oauth2proxy.text"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("GetRequestProto", func() {
|
||||
Context("IsProxied is false", func() {
|
||||
BeforeEach(func() {
|
||||
req = middleware.AddRequestScope(req, &middleware.RequestScope{})
|
||||
})
|
||||
|
||||
It("returns the scheme", func() {
|
||||
Expect(util.GetRequestProto(req)).To(Equal(proto))
|
||||
})
|
||||
|
||||
It("ignores X-Forwarded-Proto and returns the scheme", func() {
|
||||
req.Header.Add("X-Forwarded-Proto", "https")
|
||||
Expect(util.GetRequestProto(req)).To(Equal(proto))
|
||||
})
|
||||
})
|
||||
|
||||
Context("IsProxied is true", func() {
|
||||
BeforeEach(func() {
|
||||
req = middleware.AddRequestScope(req, &middleware.RequestScope{
|
||||
ReverseProxy: true,
|
||||
})
|
||||
})
|
||||
|
||||
It("returns the scheme if X-Forwarded-Proto is not present", func() {
|
||||
Expect(util.GetRequestProto(req)).To(Equal(proto))
|
||||
})
|
||||
|
||||
It("returns the X-Forwarded-Proto when present", func() {
|
||||
req.Header.Add("X-Forwarded-Proto", "https")
|
||||
Expect(util.GetRequestProto(req)).To(Equal("https"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("GetRequestURI", func() {
|
||||
Context("IsProxied is false", func() {
|
||||
BeforeEach(func() {
|
||||
req = middleware.AddRequestScope(req, &middleware.RequestScope{})
|
||||
})
|
||||
|
||||
It("returns the URI", func() {
|
||||
Expect(util.GetRequestURI(req)).To(Equal(uri))
|
||||
})
|
||||
|
||||
It("ignores X-Forwarded-Uri and returns the URI", func() {
|
||||
req.Header.Add("X-Forwarded-Uri", "/some/other/path")
|
||||
Expect(util.GetRequestURI(req)).To(Equal(uri))
|
||||
})
|
||||
})
|
||||
|
||||
Context("IsProxied is true", func() {
|
||||
BeforeEach(func() {
|
||||
req = middleware.AddRequestScope(req, &middleware.RequestScope{
|
||||
ReverseProxy: true,
|
||||
})
|
||||
})
|
||||
|
||||
It("returns the URI if X-Forwarded-Uri is not present", func() {
|
||||
Expect(util.GetRequestURI(req)).To(Equal(uri))
|
||||
})
|
||||
|
||||
It("returns the X-Forwarded-Uri when present", func() {
|
||||
req.Header.Add("X-Forwarded-Uri", "/some/other/path")
|
||||
Expect(util.GetRequestURI(req)).To(Equal("/some/other/path"))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
@ -5,7 +5,6 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||
|
|
@ -220,12 +219,12 @@ func loadCookie(req *http.Request, cookieName string) (*http.Cookie, error) {
|
|||
if len(cookies) == 0 {
|
||||
return nil, fmt.Errorf("could not find cookie %s", cookieName)
|
||||
}
|
||||
return joinCookies(cookies)
|
||||
return joinCookies(cookies, cookieName)
|
||||
}
|
||||
|
||||
// joinCookies takes a slice of cookies from the request and reconstructs the
|
||||
// full session cookie
|
||||
func joinCookies(cookies []*http.Cookie) (*http.Cookie, error) {
|
||||
func joinCookies(cookies []*http.Cookie, cookieName string) (*http.Cookie, error) {
|
||||
if len(cookies) == 0 {
|
||||
return nil, fmt.Errorf("list of cookies must be > 0")
|
||||
}
|
||||
|
|
@ -236,7 +235,7 @@ func joinCookies(cookies []*http.Cookie) (*http.Cookie, error) {
|
|||
for i := 1; i < len(cookies); i++ {
|
||||
c.Value += cookies[i].Value
|
||||
}
|
||||
c.Name = strings.TrimRight(c.Name, "_0")
|
||||
c.Name = cookieName
|
||||
return c, nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -154,9 +154,58 @@ func Test_splitCookie_joinCookies(t *testing.T) {
|
|||
Value: value,
|
||||
}
|
||||
splitCookies := splitCookie(cookie)
|
||||
joinedCookie, err := joinCookies(splitCookies)
|
||||
joinedCookie, err := joinCookies(splitCookies, cookie.Name)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, *cookie, *joinedCookie)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_joinCookies_withUnderlineSuffix(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
CookieName string
|
||||
SplitOrder []int
|
||||
}{
|
||||
"Ascending order split with \"_\" suffix": {
|
||||
CookieName: "_cookie_name_",
|
||||
SplitOrder: []int{0, 1, 2, 3, 4},
|
||||
},
|
||||
"Descending order split with \"_\" suffix": {
|
||||
CookieName: "_cookie_name_",
|
||||
SplitOrder: []int{4, 3, 2, 1, 0},
|
||||
},
|
||||
"Arbitrary order split with \"_\" suffix": {
|
||||
CookieName: "_cookie_name_",
|
||||
SplitOrder: []int{3, 1, 2, 0, 4},
|
||||
},
|
||||
"Arbitrary order split with \"_0\" suffix": {
|
||||
CookieName: "_cookie_name_0",
|
||||
SplitOrder: []int{1, 3, 0, 2, 4},
|
||||
},
|
||||
"Arbitrary order split with \"_1\" suffix": {
|
||||
CookieName: "_cookie_name_1",
|
||||
SplitOrder: []int{4, 1, 3, 0, 2},
|
||||
},
|
||||
"Arbitrary order split with \"__\" suffix": {
|
||||
CookieName: "_cookie_name__",
|
||||
SplitOrder: []int{1, 0, 4, 3, 2},
|
||||
},
|
||||
}
|
||||
|
||||
for testName, testCase := range testCases {
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
cookieName := testCase.CookieName
|
||||
var splitCookies []*http.Cookie
|
||||
for _, splitSuffix := range testCase.SplitOrder {
|
||||
cookie := &http.Cookie{
|
||||
Name: splitCookieName(cookieName, splitSuffix),
|
||||
Value: strings.Repeat("v", 1000),
|
||||
}
|
||||
splitCookies = append(splitCookies, cookie)
|
||||
}
|
||||
joinedCookie, err := joinCookies(splitCookies, cookieName)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, cookieName, joinedCookie.Name)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import (
|
|||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func GetCertPool(paths []string) (*x509.CertPool, error) {
|
||||
|
|
@ -24,12 +23,3 @@ func GetCertPool(paths []string) (*x509.CertPool, error) {
|
|||
}
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
// GetRequestHost return the request host header or X-Forwarded-Host if present
|
||||
func GetRequestHost(req *http.Request) string {
|
||||
host := req.Header.Get("X-Forwarded-Host")
|
||||
if host == "" {
|
||||
host = req.Host
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,11 +4,9 @@ import (
|
|||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
"io/ioutil"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
|
|
@ -97,16 +95,3 @@ func TestGetCertPool(t *testing.T) {
|
|||
expectedSubjects := []string{testCA1Subj, testCA2Subj}
|
||||
assert.Equal(t, expectedSubjects, got)
|
||||
}
|
||||
|
||||
func TestGetRequestHost(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
|
||||
req := httptest.NewRequest("GET", "https://example.com", nil)
|
||||
host := GetRequestHost(req)
|
||||
g.Expect(host).To(Equal("example.com"))
|
||||
|
||||
proxyReq := httptest.NewRequest("GET", "http://internal.example.com", nil)
|
||||
proxyReq.Header.Add("X-Forwarded-Host", "external.example.com")
|
||||
extHost := GetRequestHost(proxyReq)
|
||||
g.Expect(extHost).To(Equal("external.example.com"))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -233,9 +233,19 @@ func parseProviderInfo(o *options.Options, msgs []string) []string {
|
|||
p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs)
|
||||
p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs)
|
||||
|
||||
// Make the OIDC Verifier accessible to all providers that can support it
|
||||
// Make the OIDC options available to all providers that support it
|
||||
p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail
|
||||
p.EmailClaim = o.OIDCEmailClaim
|
||||
p.GroupsClaim = o.OIDCGroupsClaim
|
||||
p.Verifier = o.GetOIDCVerifier()
|
||||
|
||||
// TODO (@NickMeves) - Remove This
|
||||
// Backwards Compatibility for Deprecated UserIDClaim option
|
||||
if o.OIDCEmailClaim == providers.OIDCEmailClaim &&
|
||||
o.UserIDClaim != providers.OIDCEmailClaim {
|
||||
p.EmailClaim = o.UserIDClaim
|
||||
}
|
||||
|
||||
p.SetAllowedGroups(o.AllowedGroups)
|
||||
|
||||
provider := providers.New(o.ProviderType, p)
|
||||
|
|
@ -253,7 +263,10 @@ func parseProviderInfo(o *options.Options, msgs []string) []string {
|
|||
p.SetRepo(o.GitHubRepo, o.GitHubToken)
|
||||
p.SetUsers(o.GitHubUsers)
|
||||
case *providers.KeycloakProvider:
|
||||
p.SetGroup(o.KeycloakGroup)
|
||||
// Backwards compatibility with `--keycloak-group` option
|
||||
if len(o.KeycloakGroups) > 0 {
|
||||
p.SetAllowedGroups(o.KeycloakGroups)
|
||||
}
|
||||
case *providers.GoogleProvider:
|
||||
if o.GoogleServiceAccountJSON != "" {
|
||||
file, err := os.Open(o.GoogleServiceAccountJSON)
|
||||
|
|
@ -273,14 +286,10 @@ func parseProviderInfo(o *options.Options, msgs []string) []string {
|
|||
p.SetTeam(o.BitbucketTeam)
|
||||
p.SetRepository(o.BitbucketRepository)
|
||||
case *providers.OIDCProvider:
|
||||
p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail
|
||||
p.UserIDClaim = o.UserIDClaim
|
||||
p.GroupsClaim = o.OIDCGroupsClaim
|
||||
if p.Verifier == nil {
|
||||
msgs = append(msgs, "oidc provider requires an oidc issuer URL")
|
||||
}
|
||||
case *providers.GitLabProvider:
|
||||
p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail
|
||||
p.Groups = o.GitLabGroup
|
||||
err := p.AddProjects(o.GitlabProjects)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -20,8 +20,6 @@ type GitLabProvider struct {
|
|||
|
||||
Groups []string
|
||||
Projects []*GitlabProject
|
||||
|
||||
AllowUnverifiedEmail bool
|
||||
}
|
||||
|
||||
// GitlabProject represents a Gitlab project constraint entity
|
||||
|
|
@ -103,7 +101,7 @@ func (p *GitLabProvider) Redeem(ctx context.Context, redirectURL, code string) (
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("token exchange: %v", err)
|
||||
}
|
||||
s, err = p.createSessionState(ctx, token)
|
||||
s, err = p.createSession(ctx, token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to update session: %v", err)
|
||||
}
|
||||
|
|
@ -162,7 +160,7 @@ func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.Ses
|
|||
if err != nil {
|
||||
return fmt.Errorf("failed to get token: %v", err)
|
||||
}
|
||||
newSession, err := p.createSessionState(ctx, token)
|
||||
newSession, err := p.createSession(ctx, token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to update session: %v", err)
|
||||
}
|
||||
|
|
@ -255,22 +253,21 @@ func (p *GitLabProvider) AddProjects(projects []string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *GitLabProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*sessions.SessionState, error) {
|
||||
rawIDToken, ok := token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("token response did not contain an id_token")
|
||||
}
|
||||
|
||||
// Parse and verify ID Token payload.
|
||||
idToken, err := p.Verifier.Verify(ctx, rawIDToken)
|
||||
func (p *GitLabProvider) createSession(ctx context.Context, token *oauth2.Token) (*sessions.SessionState, error) {
|
||||
idToken, err := p.verifyIDToken(ctx, token)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case ErrMissingIDToken:
|
||||
return nil, fmt.Errorf("token response did not contain an id_token")
|
||||
default:
|
||||
return nil, fmt.Errorf("could not verify id_token: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
created := time.Now()
|
||||
return &sessions.SessionState{
|
||||
AccessToken: token.AccessToken,
|
||||
IDToken: rawIDToken,
|
||||
IDToken: getIDToken(token),
|
||||
RefreshToken: token.RefreshToken,
|
||||
CreatedAt: &created,
|
||||
ExpiresOn: &idToken.Expiry,
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package providers
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
|
|
@ -11,7 +12,6 @@ import (
|
|||
|
||||
type KeycloakProvider struct {
|
||||
*ProviderData
|
||||
Group string
|
||||
}
|
||||
|
||||
var _ Provider = (*KeycloakProvider)(nil)
|
||||
|
|
@ -47,6 +47,7 @@ var (
|
|||
}
|
||||
)
|
||||
|
||||
// NewKeycloakProvider creates a KeyCloakProvider using the passed ProviderData
|
||||
func NewKeycloakProvider(p *ProviderData) *KeycloakProvider {
|
||||
p.setProviderDefaults(providerDefaults{
|
||||
name: keycloakProviderName,
|
||||
|
|
@ -59,41 +60,39 @@ func NewKeycloakProvider(p *ProviderData) *KeycloakProvider {
|
|||
return &KeycloakProvider{ProviderData: p}
|
||||
}
|
||||
|
||||
func (p *KeycloakProvider) SetGroup(group string) {
|
||||
p.Group = group
|
||||
// EnrichSession uses the Keycloak userinfo endpoint to populate the session's
|
||||
// email and groups.
|
||||
func (p *KeycloakProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error {
|
||||
// Fallback to ValidateURL if ProfileURL not set for legacy compatibility
|
||||
profileURL := p.ValidateURL.String()
|
||||
if p.ProfileURL.String() != "" {
|
||||
profileURL = p.ProfileURL.String()
|
||||
}
|
||||
|
||||
func (p *KeycloakProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
|
||||
json, err := requests.New(p.ValidateURL.String()).
|
||||
json, err := requests.New(profileURL).
|
||||
WithContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+s.AccessToken).
|
||||
Do().
|
||||
UnmarshalJSON()
|
||||
if err != nil {
|
||||
logger.Errorf("failed making request %s", err)
|
||||
return "", err
|
||||
logger.Errorf("failed making request %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if p.Group != "" {
|
||||
var groups, err = json.Get("groups").Array()
|
||||
groups, err := json.Get("groups").StringArray()
|
||||
if err == nil {
|
||||
for _, group := range groups {
|
||||
if group != "" {
|
||||
s.Groups = append(s.Groups, group)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
email, err := json.Get("email").String()
|
||||
if err != nil {
|
||||
logger.Printf("groups not found %s", err)
|
||||
return "", err
|
||||
return fmt.Errorf("unable to extract email from userinfo endpoint: %v", err)
|
||||
}
|
||||
s.Email = email
|
||||
|
||||
var found = false
|
||||
for i := range groups {
|
||||
if groups[i].(string) == p.Group {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
logger.Printf("group not found, access denied")
|
||||
return "", nil
|
||||
}
|
||||
}
|
||||
|
||||
return json.Get("email").String()
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,17 +2,24 @@ package providers
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/ginkgo/extensions/table"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func testKeycloakProvider(hostname, group string) *KeycloakProvider {
|
||||
const (
|
||||
keycloakAccessToken = "eyJKeycloak.eyJAccess.Token"
|
||||
keycloakUserinfoPath = "/api/v3/user"
|
||||
)
|
||||
|
||||
func testKeycloakProvider(backend *httptest.Server) (*KeycloakProvider, error) {
|
||||
p := NewKeycloakProvider(
|
||||
&ProviderData{
|
||||
ProviderName: "",
|
||||
|
|
@ -22,63 +29,35 @@ func testKeycloakProvider(hostname, group string) *KeycloakProvider {
|
|||
ValidateURL: &url.URL{},
|
||||
Scope: ""})
|
||||
|
||||
if group != "" {
|
||||
p.SetGroup(group)
|
||||
if backend != nil {
|
||||
bURL, err := url.Parse(backend.URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hostname := bURL.Host
|
||||
|
||||
if hostname != "" {
|
||||
updateURL(p.Data().LoginURL, hostname)
|
||||
updateURL(p.Data().RedeemURL, hostname)
|
||||
updateURL(p.Data().ProfileURL, hostname)
|
||||
updateURL(p.Data().ValidateURL, hostname)
|
||||
}
|
||||
return p
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func testKeycloakBackend(payload string) *httptest.Server {
|
||||
path := "/api/v3/user"
|
||||
|
||||
return httptest.NewServer(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
url := r.URL
|
||||
if url.Path != path {
|
||||
w.WriteHeader(404)
|
||||
} else if !IsAuthorizedInHeader(r.Header) {
|
||||
w.WriteHeader(403)
|
||||
} else {
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(payload))
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
func TestKeycloakProviderDefaults(t *testing.T) {
|
||||
p := testKeycloakProvider("", "")
|
||||
assert.NotEqual(t, nil, p)
|
||||
assert.Equal(t, "Keycloak", p.Data().ProviderName)
|
||||
assert.Equal(t, "https://keycloak.org/oauth/authorize",
|
||||
p.Data().LoginURL.String())
|
||||
assert.Equal(t, "https://keycloak.org/oauth/token",
|
||||
p.Data().RedeemURL.String())
|
||||
assert.Equal(t, "https://keycloak.org/api/v3/user",
|
||||
p.Data().ValidateURL.String())
|
||||
assert.Equal(t, "api", p.Data().Scope)
|
||||
}
|
||||
|
||||
func TestNewKeycloakProvider(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
|
||||
// Test that defaults are set when calling for a new provider with nothing set
|
||||
var _ = Describe("Keycloak Provider Tests", func() {
|
||||
Context("New Provider Init", func() {
|
||||
It("uses defaults", func() {
|
||||
providerData := NewKeycloakProvider(&ProviderData{}).Data()
|
||||
g.Expect(providerData.ProviderName).To(Equal("Keycloak"))
|
||||
g.Expect(providerData.LoginURL.String()).To(Equal("https://keycloak.org/oauth/authorize"))
|
||||
g.Expect(providerData.RedeemURL.String()).To(Equal("https://keycloak.org/oauth/token"))
|
||||
g.Expect(providerData.ProfileURL.String()).To(Equal(""))
|
||||
g.Expect(providerData.ValidateURL.String()).To(Equal("https://keycloak.org/api/v3/user"))
|
||||
g.Expect(providerData.Scope).To(Equal("api"))
|
||||
}
|
||||
Expect(providerData.ProviderName).To(Equal("Keycloak"))
|
||||
Expect(providerData.LoginURL.String()).To(Equal("https://keycloak.org/oauth/authorize"))
|
||||
Expect(providerData.RedeemURL.String()).To(Equal("https://keycloak.org/oauth/token"))
|
||||
Expect(providerData.ProfileURL.String()).To(Equal(""))
|
||||
Expect(providerData.ValidateURL.String()).To(Equal("https://keycloak.org/api/v3/user"))
|
||||
Expect(providerData.Scope).To(Equal("api"))
|
||||
})
|
||||
|
||||
func TestKeycloakProviderOverrides(t *testing.T) {
|
||||
It("overrides defaults", func() {
|
||||
p := NewKeycloakProvider(
|
||||
&ProviderData{
|
||||
LoginURL: &url.URL{
|
||||
|
|
@ -89,75 +68,143 @@ func TestKeycloakProviderOverrides(t *testing.T) {
|
|||
Scheme: "https",
|
||||
Host: "example.com",
|
||||
Path: "/oauth/token"},
|
||||
ProfileURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "example.com",
|
||||
Path: "/api/v3/user"},
|
||||
ValidateURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "example.com",
|
||||
Path: "/api/v3/user"},
|
||||
Scope: "profile"})
|
||||
assert.NotEqual(t, nil, p)
|
||||
assert.Equal(t, "Keycloak", p.Data().ProviderName)
|
||||
assert.Equal(t, "https://example.com/oauth/auth",
|
||||
p.Data().LoginURL.String())
|
||||
assert.Equal(t, "https://example.com/oauth/token",
|
||||
p.Data().RedeemURL.String())
|
||||
assert.Equal(t, "https://example.com/api/v3/user",
|
||||
p.Data().ValidateURL.String())
|
||||
assert.Equal(t, "profile", p.Data().Scope)
|
||||
providerData := p.Data()
|
||||
|
||||
Expect(providerData.ProviderName).To(Equal("Keycloak"))
|
||||
Expect(providerData.LoginURL.String()).To(Equal("https://example.com/oauth/auth"))
|
||||
Expect(providerData.RedeemURL.String()).To(Equal("https://example.com/oauth/token"))
|
||||
Expect(providerData.ProfileURL.String()).To(Equal("https://example.com/api/v3/user"))
|
||||
Expect(providerData.ValidateURL.String()).To(Equal("https://example.com/api/v3/user"))
|
||||
Expect(providerData.Scope).To(Equal("profile"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("EnrichSession", func() {
|
||||
type enrichSessionTableInput struct {
|
||||
backendHandler http.HandlerFunc
|
||||
expectedError error
|
||||
expectedEmail string
|
||||
expectedGroups []string
|
||||
}
|
||||
|
||||
func TestKeycloakProviderGetEmailAddress(t *testing.T) {
|
||||
b := testKeycloakBackend("{\"email\": \"michael.bland@gsa.gov\"}")
|
||||
defer b.Close()
|
||||
DescribeTable("should return expected results",
|
||||
func(in enrichSessionTableInput) {
|
||||
backend := httptest.NewServer(in.backendHandler)
|
||||
p, err := testKeycloakProvider(backend)
|
||||
Expect(err).To(BeNil())
|
||||
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testKeycloakProvider(bURL.Host, "")
|
||||
p.ProfileURL, err = url.Parse(
|
||||
fmt.Sprintf("%s%s", backend.URL, keycloakUserinfoPath),
|
||||
)
|
||||
Expect(err).To(BeNil())
|
||||
|
||||
session := CreateAuthorizedSession()
|
||||
email, err := p.GetEmailAddress(context.Background(), session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "michael.bland@gsa.gov", email)
|
||||
session := &sessions.SessionState{AccessToken: keycloakAccessToken}
|
||||
err = p.EnrichSession(context.Background(), session)
|
||||
|
||||
if in.expectedError != nil {
|
||||
Expect(err).To(Equal(in.expectedError))
|
||||
} else {
|
||||
Expect(err).To(BeNil())
|
||||
}
|
||||
|
||||
func TestKeycloakProviderGetEmailAddressAndGroup(t *testing.T) {
|
||||
b := testKeycloakBackend("{\"email\": \"michael.bland@gsa.gov\", \"groups\": [\"test-grp1\", \"test-grp2\"]}")
|
||||
defer b.Close()
|
||||
Expect(session.Email).To(Equal(in.expectedEmail))
|
||||
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testKeycloakProvider(bURL.Host, "test-grp1")
|
||||
|
||||
session := CreateAuthorizedSession()
|
||||
email, err := p.GetEmailAddress(context.Background(), session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "michael.bland@gsa.gov", email)
|
||||
if in.expectedGroups != nil {
|
||||
Expect(session.Groups).To(Equal(in.expectedGroups))
|
||||
} else {
|
||||
Expect(session.Groups).To(BeNil())
|
||||
}
|
||||
|
||||
// Note that trying to trigger the "failed building request" case is not
|
||||
// practical, since the only way it can fail is if the URL fails to parse.
|
||||
func TestKeycloakProviderGetEmailAddressFailedRequest(t *testing.T) {
|
||||
b := testKeycloakBackend("unused payload")
|
||||
defer b.Close()
|
||||
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testKeycloakProvider(bURL.Host, "")
|
||||
|
||||
// We'll trigger a request failure by using an unexpected access
|
||||
// token. Alternatively, we could allow the parsing of the payload as
|
||||
// JSON to fail.
|
||||
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
|
||||
email, err := p.GetEmailAddress(context.Background(), session)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, "", email)
|
||||
},
|
||||
Entry("email and multiple groups", enrichSessionTableInput{
|
||||
backendHandler: func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
_, err := w.Write([]byte(`
|
||||
{
|
||||
"email": "michael.bland@gsa.gov",
|
||||
"groups": [
|
||||
"test-grp1",
|
||||
"test-grp2"
|
||||
]
|
||||
}
|
||||
|
||||
func TestKeycloakProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
|
||||
b := testKeycloakBackend("{\"foo\": \"bar\"}")
|
||||
defer b.Close()
|
||||
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testKeycloakProvider(bURL.Host, "")
|
||||
|
||||
session := CreateAuthorizedSession()
|
||||
email, err := p.GetEmailAddress(context.Background(), session)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, "", email)
|
||||
`))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
},
|
||||
expectedError: nil,
|
||||
expectedEmail: "michael.bland@gsa.gov",
|
||||
expectedGroups: []string{"test-grp1", "test-grp2"},
|
||||
}),
|
||||
Entry("email and single group", enrichSessionTableInput{
|
||||
backendHandler: func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
_, err := w.Write([]byte(`
|
||||
{
|
||||
"email": "michael.bland@gsa.gov",
|
||||
"groups": ["test-grp1"]
|
||||
}
|
||||
`))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
},
|
||||
expectedError: nil,
|
||||
expectedEmail: "michael.bland@gsa.gov",
|
||||
expectedGroups: []string{"test-grp1"},
|
||||
}),
|
||||
Entry("email and no groups", enrichSessionTableInput{
|
||||
backendHandler: func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
_, err := w.Write([]byte(`
|
||||
{
|
||||
"email": "michael.bland@gsa.gov"
|
||||
}
|
||||
`))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
},
|
||||
expectedError: nil,
|
||||
expectedEmail: "michael.bland@gsa.gov",
|
||||
expectedGroups: nil,
|
||||
}),
|
||||
Entry("missing email", enrichSessionTableInput{
|
||||
backendHandler: func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
_, err := w.Write([]byte(`
|
||||
{
|
||||
"groups": [
|
||||
"test-grp1",
|
||||
"test-grp2"
|
||||
]
|
||||
}
|
||||
`))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
},
|
||||
expectedError: errors.New(
|
||||
"unable to extract email from userinfo endpoint: type assertion to string failed"),
|
||||
expectedEmail: "",
|
||||
expectedGroups: []string{"test-grp1", "test-grp2"},
|
||||
}),
|
||||
Entry("request failure", enrichSessionTableInput{
|
||||
backendHandler: func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(500)
|
||||
},
|
||||
expectedError: errors.New(`unexpected status "500": `),
|
||||
expectedEmail: "",
|
||||
expectedGroups: nil,
|
||||
}),
|
||||
)
|
||||
})
|
||||
})
|
||||
|
|
|
|||
|
|
@ -2,29 +2,20 @@ package providers
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
oidc "github.com/coreos/go-oidc"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
const emailClaim = "email"
|
||||
|
||||
// OIDCProvider represents an OIDC based Identity Provider
|
||||
type OIDCProvider struct {
|
||||
*ProviderData
|
||||
|
||||
AllowUnverifiedEmail bool
|
||||
UserIDClaim string
|
||||
GroupsClaim string
|
||||
}
|
||||
|
||||
// NewOIDCProvider initiates a new OIDCProvider
|
||||
|
|
@ -36,10 +27,10 @@ func NewOIDCProvider(p *ProviderData) *OIDCProvider {
|
|||
var _ Provider = (*OIDCProvider)(nil)
|
||||
|
||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||
func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
|
||||
func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
|
||||
clientSecret, err := p.GetClientSecret()
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c := oauth2.Config{
|
||||
|
|
@ -52,23 +43,74 @@ func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (s
|
|||
}
|
||||
token, err := c.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token exchange: %v", err)
|
||||
return nil, fmt.Errorf("token exchange failed: %v", err)
|
||||
}
|
||||
|
||||
// in the initial exchange the id token is mandatory
|
||||
idToken, err := p.findVerifiedIDToken(ctx, token)
|
||||
return p.createSession(ctx, token, false)
|
||||
}
|
||||
|
||||
// EnrichSession is called after Redeem to allow providers to enrich session fields
|
||||
// such as User, Email, Groups with provider specific API calls.
|
||||
func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error {
|
||||
if p.ProfileURL.String() == "" {
|
||||
if s.Email == "" {
|
||||
return errors.New("id_token did not contain an email and profileURL is not defined")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to get missing emails or groups from a profileURL
|
||||
if s.Email == "" || s.Groups == nil {
|
||||
err := p.enrichFromProfileURL(ctx, s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not verify id_token: %v", err)
|
||||
} else if idToken == nil {
|
||||
return nil, fmt.Errorf("token response did not contain an id_token")
|
||||
logger.Errorf("Warning: Profile URL request failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
s, err = p.createSessionState(ctx, token, idToken)
|
||||
// If a mandatory email wasn't set, error at this point.
|
||||
if s.Email == "" {
|
||||
return errors.New("neither the id_token nor the profileURL set an email")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// enrichFromProfileURL enriches a session's Email & Groups via the JSON response of
|
||||
// an OIDC profile URL
|
||||
func (p *OIDCProvider) enrichFromProfileURL(ctx context.Context, s *sessions.SessionState) error {
|
||||
respJSON, err := requests.New(p.ProfileURL.String()).
|
||||
WithContext(ctx).
|
||||
WithHeaders(makeOIDCHeader(s.AccessToken)).
|
||||
Do().
|
||||
UnmarshalJSON()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to update session: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return
|
||||
email, err := respJSON.Get(p.EmailClaim).String()
|
||||
if err == nil && s.Email == "" {
|
||||
s.Email = email
|
||||
}
|
||||
|
||||
if len(s.Groups) > 0 {
|
||||
return nil
|
||||
}
|
||||
for _, group := range coerceArray(respJSON, p.GroupsClaim) {
|
||||
formatted, err := formatGroup(group)
|
||||
if err != nil {
|
||||
logger.Errorf("Warning: unable to format group of type %s with error %s",
|
||||
reflect.TypeOf(group), err)
|
||||
continue
|
||||
}
|
||||
s.Groups = append(s.Groups, formatted)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateSession checks that the session's IDToken is still valid
|
||||
func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
|
||||
_, err := p.Verifier.Verify(ctx, s.IDToken)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||
|
|
@ -83,14 +125,16 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.S
|
|||
return false, fmt.Errorf("unable to redeem refresh token: %v", err)
|
||||
}
|
||||
|
||||
fmt.Printf("refreshed access token %s (expired on %s)\n", s, s.ExpiresOn)
|
||||
logger.Printf("refreshed session: %s", s)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) {
|
||||
// redeemRefreshToken uses a RefreshToken with the RedeemURL to refresh the
|
||||
// Access Token and (probably) the ID Token.
|
||||
func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error {
|
||||
clientSecret, err := p.GetClientSecret()
|
||||
if err != nil {
|
||||
return
|
||||
return err
|
||||
}
|
||||
|
||||
c := oauth2.Config{
|
||||
|
|
@ -109,19 +153,14 @@ func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sessi
|
|||
return fmt.Errorf("failed to get token: %v", err)
|
||||
}
|
||||
|
||||
// in the token refresh response the id_token is optional
|
||||
idToken, err := p.findVerifiedIDToken(ctx, token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to extract id_token from response: %v", err)
|
||||
}
|
||||
|
||||
newSession, err := p.createSessionState(ctx, token, idToken)
|
||||
newSession, err := p.createSession(ctx, token, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable create new session state from response: %v", err)
|
||||
}
|
||||
|
||||
// It's possible that if the refresh token isn't in the token response the session will not contain an id token
|
||||
// if it doesn't it's probably better to retain the old one
|
||||
// It's possible that if the refresh token isn't in the token response the
|
||||
// session will not contain an id token.
|
||||
// If it doesn't it's probably better to retain the old one
|
||||
if newSession.IDToken != "" {
|
||||
s.IDToken = newSession.IDToken
|
||||
s.Email = newSession.Email
|
||||
|
|
@ -135,193 +174,62 @@ func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sessi
|
|||
s.CreatedAt = newSession.CreatedAt
|
||||
s.ExpiresOn = newSession.ExpiresOn
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) findVerifiedIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) {
|
||||
|
||||
getIDToken := func() (string, bool) {
|
||||
rawIDToken, _ := token.Extra("id_token").(string)
|
||||
return rawIDToken, len(strings.TrimSpace(rawIDToken)) > 0
|
||||
}
|
||||
|
||||
if rawIDToken, present := getIDToken(); present {
|
||||
verifiedIDToken, err := p.Verifier.Verify(ctx, rawIDToken)
|
||||
return verifiedIDToken, err
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token, idToken *oidc.IDToken) (*sessions.SessionState, error) {
|
||||
|
||||
var newSession *sessions.SessionState
|
||||
|
||||
if idToken == nil {
|
||||
newSession = &sessions.SessionState{}
|
||||
} else {
|
||||
var err error
|
||||
newSession, err = p.createSessionStateInternal(ctx, idToken, token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
created := time.Now()
|
||||
newSession.AccessToken = token.AccessToken
|
||||
newSession.RefreshToken = token.RefreshToken
|
||||
newSession.CreatedAt = &created
|
||||
newSession.ExpiresOn = &token.Expiry
|
||||
return newSession, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateSessionFromToken converts Bearer IDTokens into sessions
|
||||
func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) {
|
||||
idToken, err := p.Verifier.Verify(ctx, token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newSession, err := p.createSessionStateInternal(ctx, idToken, nil)
|
||||
ss, err := p.buildSessionFromClaims(idToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newSession.AccessToken = token
|
||||
newSession.IDToken = token
|
||||
newSession.RefreshToken = ""
|
||||
newSession.ExpiresOn = &idToken.Expiry
|
||||
|
||||
return newSession, nil
|
||||
// Allow empty Email in Bearer case since we can't hit the ProfileURL
|
||||
if ss.Email == "" {
|
||||
ss.Email = ss.User
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) createSessionStateInternal(ctx context.Context, idToken *oidc.IDToken, token *oauth2.Token) (*sessions.SessionState, error) {
|
||||
ss.AccessToken = token
|
||||
ss.IDToken = token
|
||||
ss.RefreshToken = ""
|
||||
ss.ExpiresOn = &idToken.Expiry
|
||||
|
||||
newSession := &sessions.SessionState{}
|
||||
|
||||
if idToken == nil {
|
||||
return newSession, nil
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
claims, err := p.findClaimsFromIDToken(ctx, idToken, token)
|
||||
// createSession takes an oauth2.Token and creates a SessionState from it.
|
||||
// It alters behavior if called from Redeem vs Refresh
|
||||
func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, refresh bool) (*sessions.SessionState, error) {
|
||||
idToken, err := p.verifyIDToken(ctx, token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("couldn't extract claims from id_token (%v)", err)
|
||||
switch err {
|
||||
case ErrMissingIDToken:
|
||||
// IDToken is mandatory in Redeem but optional in Refresh
|
||||
if !refresh {
|
||||
return nil, errors.New("token response did not contain an id_token")
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("could not verify id_token: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if token != nil {
|
||||
newSession.IDToken = token.Extra("id_token").(string)
|
||||
}
|
||||
|
||||
newSession.Email = claims.UserID // TODO Rename SessionState.Email to .UserID in the near future
|
||||
|
||||
newSession.User = claims.Subject
|
||||
newSession.Groups = claims.Groups
|
||||
newSession.PreferredUsername = claims.PreferredUsername
|
||||
|
||||
verifyEmail := (p.UserIDClaim == emailClaim) && !p.AllowUnverifiedEmail
|
||||
if verifyEmail && claims.Verified != nil && !*claims.Verified {
|
||||
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.UserID)
|
||||
}
|
||||
|
||||
return newSession, nil
|
||||
}
|
||||
|
||||
// ValidateSessionState checks that the session's IDToken is still valid
|
||||
func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
|
||||
_, err := p.Verifier.Verify(ctx, s.IDToken)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc.IDToken, token *oauth2.Token) (*OIDCClaims, error) {
|
||||
claims := &OIDCClaims{}
|
||||
|
||||
// Extract default claims.
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse default id_token claims: %v", err)
|
||||
}
|
||||
// Extract custom claims.
|
||||
if err := idToken.Claims(&claims.rawClaims); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse all id_token claims: %v", err)
|
||||
}
|
||||
|
||||
userID := claims.rawClaims[p.UserIDClaim]
|
||||
if userID != nil {
|
||||
claims.UserID = fmt.Sprint(userID)
|
||||
}
|
||||
|
||||
claims.Groups = p.extractGroupsFromRawClaims(claims.rawClaims)
|
||||
|
||||
// userID claim was not present or was empty in the ID Token
|
||||
if claims.UserID == "" {
|
||||
// BearerToken case, allow empty UserID
|
||||
// ProfileURL checks below won't work since we don't have an access token
|
||||
if token == nil {
|
||||
claims.UserID = claims.Subject
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
profileURL := p.ProfileURL.String()
|
||||
if profileURL == "" || token.AccessToken == "" {
|
||||
return nil, fmt.Errorf("id_token did not contain user ID claim (%q)", p.UserIDClaim)
|
||||
}
|
||||
|
||||
// If the userinfo endpoint profileURL is defined, then there is a chance the userinfo
|
||||
// contents at the profileURL contains the email.
|
||||
// Make a query to the userinfo endpoint, and attempt to locate the email from there.
|
||||
respJSON, err := requests.New(profileURL).
|
||||
WithContext(ctx).
|
||||
WithHeaders(makeOIDCHeader(token.AccessToken)).
|
||||
Do().
|
||||
UnmarshalJSON()
|
||||
ss, err := p.buildSessionFromClaims(idToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userID, err := respJSON.Get(p.UserIDClaim).String()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("neither id_token nor userinfo endpoint contained user ID claim (%q)", p.UserIDClaim)
|
||||
}
|
||||
ss.AccessToken = token.AccessToken
|
||||
ss.RefreshToken = token.RefreshToken
|
||||
ss.IDToken = getIDToken(token)
|
||||
|
||||
claims.UserID = userID
|
||||
}
|
||||
created := time.Now()
|
||||
ss.CreatedAt = &created
|
||||
ss.ExpiresOn = &token.Expiry
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) extractGroupsFromRawClaims(rawClaims map[string]interface{}) []string {
|
||||
groups := []string{}
|
||||
|
||||
rawGroups, ok := rawClaims[p.GroupsClaim].([]interface{})
|
||||
if rawGroups != nil && ok {
|
||||
for _, rawGroup := range rawGroups {
|
||||
formattedGroup, err := formatGroup(rawGroup)
|
||||
if err != nil {
|
||||
logger.Errorf("unable to format group of type %s with error %s", reflect.TypeOf(rawGroup), err)
|
||||
continue
|
||||
}
|
||||
groups = append(groups, formattedGroup)
|
||||
}
|
||||
}
|
||||
|
||||
return groups
|
||||
|
||||
}
|
||||
|
||||
func formatGroup(rawGroup interface{}) (string, error) {
|
||||
group, ok := rawGroup.(string)
|
||||
if !ok {
|
||||
jsonGroup, err := json.Marshal(rawGroup)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
group = string(jsonGroup)
|
||||
}
|
||||
return group, nil
|
||||
}
|
||||
|
||||
type OIDCClaims struct {
|
||||
rawClaims map[string]interface{}
|
||||
UserID string
|
||||
Subject string `json:"sub"`
|
||||
Verified *bool `json:"email_verified"`
|
||||
PreferredUsername string `json:"preferred_username"`
|
||||
Groups []string `json:"-"`
|
||||
return ss, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,42 +2,18 @@ package providers
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const accessToken = "access_token"
|
||||
const refreshToken = "refresh_token"
|
||||
const clientID = "https://test.myapp.com"
|
||||
const secret = "secret"
|
||||
|
||||
type idTokenClaims struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
Phone string `json:"phone_number,omitempty"`
|
||||
Picture string `json:"picture,omitempty"`
|
||||
Groups interface{} `json:"groups,omitempty"`
|
||||
OtherGroups interface{} `json:"other_groups,omitempty"`
|
||||
jwt.StandardClaims
|
||||
}
|
||||
|
||||
type redeemTokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
|
|
@ -46,88 +22,12 @@ type redeemTokenResponse struct {
|
|||
IDToken string `json:"id_token,omitempty"`
|
||||
}
|
||||
|
||||
var defaultIDToken idTokenClaims = idTokenClaims{
|
||||
"Jane Dobbs",
|
||||
"janed@me.com",
|
||||
"+4798765432",
|
||||
"http://mugbook.com/janed/me.jpg",
|
||||
[]string{"test:a", "test:b"},
|
||||
[]string{"test:c", "test:d"},
|
||||
jwt.StandardClaims{
|
||||
Audience: "https://test.myapp.com",
|
||||
ExpiresAt: time.Now().Add(time.Duration(5) * time.Minute).Unix(),
|
||||
Id: "id-some-id",
|
||||
IssuedAt: time.Now().Unix(),
|
||||
Issuer: "https://issuer.example.com",
|
||||
NotBefore: 0,
|
||||
Subject: "123456789",
|
||||
},
|
||||
}
|
||||
|
||||
var customGroupClaimIDToken idTokenClaims = idTokenClaims{
|
||||
"Jane Dobbs",
|
||||
"janed@me.com",
|
||||
"+4798765432",
|
||||
"http://mugbook.com/janed/me.jpg",
|
||||
[]map[string]interface{}{
|
||||
{
|
||||
"groupId": "Admin Group Id",
|
||||
"roles": []string{"Admin"},
|
||||
},
|
||||
},
|
||||
[]string{"test:c", "test:d"},
|
||||
jwt.StandardClaims{
|
||||
Audience: "https://test.myapp.com",
|
||||
ExpiresAt: time.Now().Add(time.Duration(5) * time.Minute).Unix(),
|
||||
Id: "id-some-id",
|
||||
IssuedAt: time.Now().Unix(),
|
||||
Issuer: "https://issuer.example.com",
|
||||
NotBefore: 0,
|
||||
Subject: "123456789",
|
||||
},
|
||||
}
|
||||
|
||||
var minimalIDToken idTokenClaims = idTokenClaims{
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
[]string{},
|
||||
[]string{},
|
||||
jwt.StandardClaims{
|
||||
Audience: "https://test.myapp.com",
|
||||
ExpiresAt: time.Now().Add(time.Duration(5) * time.Minute).Unix(),
|
||||
Id: "id-some-id",
|
||||
IssuedAt: time.Now().Unix(),
|
||||
Issuer: "https://issuer.example.com",
|
||||
NotBefore: 0,
|
||||
Subject: "minimal",
|
||||
},
|
||||
}
|
||||
|
||||
type fakeKeySetStub struct{}
|
||||
|
||||
func (fakeKeySetStub) VerifySignature(_ context.Context, jwt string) (payload []byte, err error) {
|
||||
decodeString, err := base64.RawURLEncoding.DecodeString(strings.Split(jwt, ".")[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokenClaims := &idTokenClaims{}
|
||||
err = json.Unmarshal(decodeString, tokenClaims)
|
||||
|
||||
if err != nil || tokenClaims.Id == "this-id-fails-validation" {
|
||||
return nil, fmt.Errorf("the validation failed for subject [%v]", tokenClaims.Subject)
|
||||
}
|
||||
|
||||
return decodeString, err
|
||||
}
|
||||
|
||||
func newOIDCProvider(serverURL *url.URL) *OIDCProvider {
|
||||
|
||||
providerData := &ProviderData{
|
||||
ProviderName: "oidc",
|
||||
ClientID: clientID,
|
||||
ClientSecret: secret,
|
||||
ClientID: oidcClientID,
|
||||
ClientSecret: oidcSecret,
|
||||
LoginURL: &url.URL{
|
||||
Scheme: serverURL.Scheme,
|
||||
Host: serverURL.Host,
|
||||
|
|
@ -145,17 +45,16 @@ func newOIDCProvider(serverURL *url.URL) *OIDCProvider {
|
|||
Host: serverURL.Host,
|
||||
Path: "/api"},
|
||||
Scope: "openid profile offline_access",
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
Verifier: oidc.NewVerifier(
|
||||
"https://issuer.example.com",
|
||||
fakeKeySetStub{},
|
||||
&oidc.Config{ClientID: clientID},
|
||||
oidcIssuer,
|
||||
mockJWKS{},
|
||||
&oidc.Config{ClientID: oidcClientID},
|
||||
),
|
||||
}
|
||||
|
||||
p := &OIDCProvider{
|
||||
ProviderData: providerData,
|
||||
UserIDClaim: "email",
|
||||
}
|
||||
p := &OIDCProvider{ProviderData: providerData}
|
||||
|
||||
return p
|
||||
}
|
||||
|
|
@ -169,22 +68,7 @@ func newOIDCServer(body []byte) (*url.URL, *httptest.Server) {
|
|||
return u, s
|
||||
}
|
||||
|
||||
func newSignedTestIDToken(tokenClaims idTokenClaims) (string, error) {
|
||||
key, _ := rsa.GenerateKey(rand.Reader, 2048)
|
||||
standardClaims := jwt.NewWithClaims(jwt.SigningMethodRS256, tokenClaims)
|
||||
return standardClaims.SignedString(key)
|
||||
}
|
||||
|
||||
func newOauth2Token() *oauth2.Token {
|
||||
return &oauth2.Token{
|
||||
AccessToken: accessToken,
|
||||
TokenType: "Bearer",
|
||||
RefreshToken: refreshToken,
|
||||
Expiry: time.Time{}.Add(time.Duration(5) * time.Second),
|
||||
}
|
||||
}
|
||||
|
||||
func newTestSetup(body []byte) (*httptest.Server, *OIDCProvider) {
|
||||
func newTestOIDCSetup(body []byte) (*httptest.Server, *OIDCProvider) {
|
||||
redeemURL, server := newOIDCServer(body)
|
||||
provider := newOIDCProvider(redeemURL)
|
||||
return server, provider
|
||||
|
|
@ -201,7 +85,7 @@ func TestOIDCProviderRedeem(t *testing.T) {
|
|||
IDToken: idToken,
|
||||
})
|
||||
|
||||
server, provider := newTestSetup(body)
|
||||
server, provider := newTestOIDCSetup(body)
|
||||
defer server.Close()
|
||||
|
||||
session, err := provider.Redeem(context.Background(), provider.RedeemURL.String(), "code1234")
|
||||
|
|
@ -224,8 +108,8 @@ func TestOIDCProviderRedeem_custom_userid(t *testing.T) {
|
|||
IDToken: idToken,
|
||||
})
|
||||
|
||||
server, provider := newTestSetup(body)
|
||||
provider.UserIDClaim = "phone_number"
|
||||
server, provider := newTestOIDCSetup(body)
|
||||
provider.EmailClaim = "phone_number"
|
||||
defer server.Close()
|
||||
|
||||
session, err := provider.Redeem(context.Background(), provider.RedeemURL.String(), "code1234")
|
||||
|
|
@ -233,6 +117,333 @@ func TestOIDCProviderRedeem_custom_userid(t *testing.T) {
|
|||
assert.Equal(t, defaultIDToken.Phone, session.Email)
|
||||
}
|
||||
|
||||
func TestOIDCProvider_EnrichSession(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
ExistingSession *sessions.SessionState
|
||||
EmailClaim string
|
||||
GroupsClaim string
|
||||
ProfileJSON map[string]interface{}
|
||||
ExpectedError error
|
||||
ExpectedSession *sessions.SessionState
|
||||
}{
|
||||
"Already Populated": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: []string{"already", "populated"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"email": "new@thing.com",
|
||||
"groups": []string{"new", "thing"},
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: []string{"already", "populated"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Missing Email": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "missing.email",
|
||||
Groups: []string{"already", "populated"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"email": "found@email.com",
|
||||
"groups": []string{"new", "thing"},
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "missing.email",
|
||||
Email: "found@email.com",
|
||||
Groups: []string{"already", "populated"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
|
||||
"Missing Email Only in Profile URL": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "missing.email",
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"email": "found@email.com",
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "missing.email",
|
||||
Email: "found@email.com",
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Missing Email with Custom Claim": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "missing.email",
|
||||
Groups: []string{"already", "populated"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "weird",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"weird": "weird@claim.com",
|
||||
"groups": []string{"new", "thing"},
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "missing.email",
|
||||
Email: "weird@claim.com",
|
||||
Groups: []string{"already", "populated"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Missing Email not in Profile URL": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "missing.email",
|
||||
Groups: []string{"already", "populated"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"groups": []string{"new", "thing"},
|
||||
},
|
||||
ExpectedError: errors.New("neither the id_token nor the profileURL set an email"),
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "missing.email",
|
||||
Groups: []string{"already", "populated"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Missing Groups": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: nil,
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"email": "new@thing.com",
|
||||
"groups": []string{"new", "thing"},
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: []string{"new", "thing"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Missing Groups with Complex Groups in Profile URL": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: nil,
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"email": "new@thing.com",
|
||||
"groups": []map[string]interface{}{
|
||||
{
|
||||
"groupId": "Admin Group Id",
|
||||
"roles": []string{"Admin"},
|
||||
},
|
||||
},
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Missing Groups with Singleton Complex Group in Profile URL": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: nil,
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"email": "new@thing.com",
|
||||
"groups": map[string]interface{}{
|
||||
"groupId": "Admin Group Id",
|
||||
"roles": []string{"Admin"},
|
||||
},
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Empty Groups Claims": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: []string{},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"email": "new@thing.com",
|
||||
"groups": []string{"new", "thing"},
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: []string{},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Missing Groups with Custom Claim": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: nil,
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "roles",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"email": "new@thing.com",
|
||||
"roles": []string{"new", "thing", "roles"},
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: []string{"new", "thing", "roles"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Missing Groups String Profile URL Response": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: nil,
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"email": "new@thing.com",
|
||||
"groups": "singleton",
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
Groups: []string{"singleton"},
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
"Missing Groups in both Claims and Profile URL": {
|
||||
ExistingSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ProfileJSON: map[string]interface{}{
|
||||
"email": "new@thing.com",
|
||||
},
|
||||
ExpectedError: nil,
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "already",
|
||||
Email: "already@populated.com",
|
||||
IDToken: idToken,
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
},
|
||||
},
|
||||
}
|
||||
for testName, tc := range testCases {
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
jsonResp, err := json.Marshal(tc.ProfileJSON)
|
||||
assert.NoError(t, err)
|
||||
|
||||
server, provider := newTestOIDCSetup(jsonResp)
|
||||
provider.ProfileURL, err = url.Parse(server.URL)
|
||||
assert.NoError(t, err)
|
||||
|
||||
provider.EmailClaim = tc.EmailClaim
|
||||
provider.GroupsClaim = tc.GroupsClaim
|
||||
defer server.Close()
|
||||
|
||||
err = provider.EnrichSession(context.Background(), tc.ExistingSession)
|
||||
assert.Equal(t, tc.ExpectedError, err)
|
||||
assert.Equal(t, *tc.ExpectedSession, *tc.ExistingSession)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) {
|
||||
|
||||
idToken, _ := newSignedTestIDToken(defaultIDToken)
|
||||
|
|
@ -243,7 +454,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) {
|
|||
RefreshToken: refreshToken,
|
||||
})
|
||||
|
||||
server, provider := newTestSetup(body)
|
||||
server, provider := newTestOIDCSetup(body)
|
||||
defer server.Close()
|
||||
|
||||
existingSession := &sessions.SessionState{
|
||||
|
|
@ -277,7 +488,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) {
|
|||
IDToken: idToken,
|
||||
})
|
||||
|
||||
server, provider := newTestSetup(body)
|
||||
server, provider := newTestOIDCSetup(body)
|
||||
defer server.Close()
|
||||
|
||||
existingSession := &sessions.SessionState{
|
||||
|
|
@ -300,48 +511,45 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestOIDCProviderCreateSessionFromToken(t *testing.T) {
|
||||
const profileURLEmail = "janed@me.com"
|
||||
|
||||
testCases := map[string]struct {
|
||||
IDToken idTokenClaims
|
||||
GroupsClaim string
|
||||
ExpectedUser string
|
||||
ExpectedEmail string
|
||||
ExpectedGroups interface{}
|
||||
ExpectedGroups []string
|
||||
}{
|
||||
"Default IDToken": {
|
||||
IDToken: defaultIDToken,
|
||||
GroupsClaim: "groups",
|
||||
ExpectedUser: defaultIDToken.Subject,
|
||||
ExpectedEmail: defaultIDToken.Email,
|
||||
ExpectedUser: "123456789",
|
||||
ExpectedEmail: "janed@me.com",
|
||||
ExpectedGroups: []string{"test:a", "test:b"},
|
||||
},
|
||||
"Minimal IDToken with no email claim": {
|
||||
IDToken: minimalIDToken,
|
||||
GroupsClaim: "groups",
|
||||
ExpectedUser: minimalIDToken.Subject,
|
||||
ExpectedEmail: minimalIDToken.Subject,
|
||||
ExpectedGroups: []string{},
|
||||
ExpectedUser: "123456789",
|
||||
ExpectedEmail: "123456789",
|
||||
ExpectedGroups: nil,
|
||||
},
|
||||
"Custom Groups Claim": {
|
||||
IDToken: defaultIDToken,
|
||||
GroupsClaim: "other_groups",
|
||||
ExpectedUser: defaultIDToken.Subject,
|
||||
ExpectedEmail: defaultIDToken.Email,
|
||||
GroupsClaim: "roles",
|
||||
ExpectedUser: "123456789",
|
||||
ExpectedEmail: "janed@me.com",
|
||||
ExpectedGroups: []string{"test:c", "test:d"},
|
||||
},
|
||||
"Custom Groups Claim2": {
|
||||
IDToken: customGroupClaimIDToken,
|
||||
"Complex Groups Claim": {
|
||||
IDToken: complexGroupsIDToken,
|
||||
GroupsClaim: "groups",
|
||||
ExpectedUser: customGroupClaimIDToken.Subject,
|
||||
ExpectedEmail: customGroupClaimIDToken.Email,
|
||||
ExpectedUser: "123456789",
|
||||
ExpectedEmail: "complex@claims.com",
|
||||
ExpectedGroups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"},
|
||||
},
|
||||
}
|
||||
for testName, tc := range testCases {
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
jsonResp := []byte(fmt.Sprintf(`{"email":"%s"}`, profileURLEmail))
|
||||
server, provider := newTestSetup(jsonResp)
|
||||
server, provider := newTestOIDCSetup([]byte(`{}`))
|
||||
provider.GroupsClaim = tc.GroupsClaim
|
||||
defer server.Close()
|
||||
|
||||
|
|
@ -353,75 +561,10 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) {
|
|||
|
||||
assert.Equal(t, tc.ExpectedUser, ss.User)
|
||||
assert.Equal(t, tc.ExpectedEmail, ss.Email)
|
||||
assert.Equal(t, tc.ExpectedGroups, ss.Groups)
|
||||
assert.Equal(t, rawIDToken, ss.IDToken)
|
||||
assert.Equal(t, rawIDToken, ss.AccessToken)
|
||||
assert.Equal(t, tc.ExpectedGroups, ss.Groups)
|
||||
assert.Equal(t, "", ss.RefreshToken)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCProvider_findVerifiedIdToken(t *testing.T) {
|
||||
|
||||
server, provider := newTestSetup([]byte(""))
|
||||
|
||||
defer server.Close()
|
||||
|
||||
token := newOauth2Token()
|
||||
signedIDToken, _ := newSignedTestIDToken(defaultIDToken)
|
||||
tokenWithIDToken := token.WithExtra(map[string]interface{}{
|
||||
"id_token": signedIDToken,
|
||||
})
|
||||
|
||||
verifiedIDToken, err := provider.findVerifiedIDToken(context.Background(), tokenWithIDToken)
|
||||
assert.Equal(t, true, err == nil)
|
||||
if verifiedIDToken == nil {
|
||||
t.Fatal("verifiedIDToken is nil")
|
||||
}
|
||||
assert.Equal(t, defaultIDToken.Issuer, verifiedIDToken.Issuer)
|
||||
assert.Equal(t, defaultIDToken.Subject, verifiedIDToken.Subject)
|
||||
|
||||
// When the validation fails the response should be nil
|
||||
defaultIDToken.Id = "this-id-fails-validation"
|
||||
signedIDToken, _ = newSignedTestIDToken(defaultIDToken)
|
||||
tokenWithIDToken = token.WithExtra(map[string]interface{}{
|
||||
"id_token": signedIDToken,
|
||||
})
|
||||
|
||||
verifiedIDToken, err = provider.findVerifiedIDToken(context.Background(), tokenWithIDToken)
|
||||
assert.Equal(t, errors.New("failed to verify signature: the validation failed for subject [123456789]"), err)
|
||||
assert.Equal(t, true, verifiedIDToken == nil)
|
||||
|
||||
// When there is no id token in the oauth token
|
||||
verifiedIDToken, err = provider.findVerifiedIDToken(context.Background(), newOauth2Token())
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, true, verifiedIDToken == nil)
|
||||
}
|
||||
|
||||
func Test_formatGroup(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
RawGroup interface{}
|
||||
ExpectedFormattedGroupValue string
|
||||
}{
|
||||
"String Group": {
|
||||
RawGroup: "group",
|
||||
ExpectedFormattedGroupValue: "group",
|
||||
},
|
||||
"Map Group": {
|
||||
RawGroup: map[string]string{"id": "1", "name": "Test"},
|
||||
ExpectedFormattedGroupValue: "{\"id\":\"1\",\"name\":\"Test\"}",
|
||||
},
|
||||
"List Group": {
|
||||
RawGroup: []string{"First", "Second"},
|
||||
ExpectedFormattedGroupValue: "[\"First\",\"Second\"]",
|
||||
},
|
||||
}
|
||||
|
||||
for testName, tc := range testCases {
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
formattedGroup, err := formatGroup(tc.RawGroup)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, tc.ExpectedFormattedGroupValue, formattedGroup)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,12 +1,23 @@
|
|||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
const (
|
||||
OIDCEmailClaim = "email"
|
||||
OIDCGroupsClaim = "groups"
|
||||
)
|
||||
|
||||
// ProviderData contains information required to configure all implementations
|
||||
|
|
@ -27,6 +38,11 @@ type ProviderData struct {
|
|||
ClientSecretFile string
|
||||
Scope string
|
||||
Prompt string
|
||||
|
||||
// Common OIDC options for any OIDC-based providers to consume
|
||||
AllowUnverifiedEmail bool
|
||||
EmailClaim string
|
||||
GroupsClaim string
|
||||
Verifier *oidc.IDTokenVerifier
|
||||
|
||||
// Universal Group authorization data structure
|
||||
|
|
@ -94,3 +110,116 @@ func defaultURL(u *url.URL, d *url.URL) *url.URL {
|
|||
}
|
||||
return &url.URL{}
|
||||
}
|
||||
|
||||
// ****************************************************************************
|
||||
// These private OIDC helper methods are available to any providers that are
|
||||
// OIDC compliant
|
||||
// ****************************************************************************
|
||||
|
||||
// OIDCClaims is a struct to unmarshal the OIDC claims from an ID Token payload
|
||||
type OIDCClaims struct {
|
||||
Subject string `json:"sub"`
|
||||
Email string `json:"-"`
|
||||
Groups []string `json:"-"`
|
||||
Verified *bool `json:"email_verified"`
|
||||
|
||||
raw map[string]interface{}
|
||||
}
|
||||
|
||||
func (p *ProviderData) verifyIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) {
|
||||
rawIDToken := getIDToken(token)
|
||||
if strings.TrimSpace(rawIDToken) == "" {
|
||||
return nil, ErrMissingIDToken
|
||||
}
|
||||
if p.Verifier == nil {
|
||||
return nil, ErrMissingOIDCVerifier
|
||||
}
|
||||
return p.Verifier.Verify(ctx, rawIDToken)
|
||||
}
|
||||
|
||||
// buildSessionFromClaims uses IDToken claims to populate a fresh SessionState
|
||||
// with non-Token related fields.
|
||||
func (p *ProviderData) buildSessionFromClaims(idToken *oidc.IDToken) (*sessions.SessionState, error) {
|
||||
ss := &sessions.SessionState{}
|
||||
|
||||
if idToken == nil {
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
claims, err := p.getClaims(idToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("couldn't extract claims from id_token (%v)", err)
|
||||
}
|
||||
|
||||
ss.User = claims.Subject
|
||||
ss.Email = claims.Email
|
||||
ss.Groups = claims.Groups
|
||||
|
||||
// TODO (@NickMeves) Deprecate for dynamic claim to session mapping
|
||||
if pref, ok := claims.raw["preferred_username"].(string); ok {
|
||||
ss.PreferredUsername = pref
|
||||
}
|
||||
|
||||
// `email_verified` must be present and explicitly set to `false` to be
|
||||
// considered unverified.
|
||||
verifyEmail := (p.EmailClaim == OIDCEmailClaim) && !p.AllowUnverifiedEmail
|
||||
if verifyEmail && claims.Verified != nil && !*claims.Verified {
|
||||
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email)
|
||||
}
|
||||
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
// getClaims extracts IDToken claims into an OIDCClaims
|
||||
func (p *ProviderData) getClaims(idToken *oidc.IDToken) (*OIDCClaims, error) {
|
||||
claims := &OIDCClaims{}
|
||||
|
||||
// Extract default claims.
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse default id_token claims: %v", err)
|
||||
}
|
||||
// Extract custom claims.
|
||||
if err := idToken.Claims(&claims.raw); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse all id_token claims: %v", err)
|
||||
}
|
||||
|
||||
email := claims.raw[p.EmailClaim]
|
||||
if email != nil {
|
||||
claims.Email = fmt.Sprint(email)
|
||||
}
|
||||
claims.Groups = p.extractGroups(claims.raw)
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// extractGroups extracts groups from a claim to a list in a type safe manner.
|
||||
// If the claim isn't present, `nil` is returned. If the groups claim is
|
||||
// present but empty, `[]string{}` is returned.
|
||||
func (p *ProviderData) extractGroups(claims map[string]interface{}) []string {
|
||||
rawClaim, ok := claims[p.GroupsClaim]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle traditional list-based groups as well as non-standard singleton
|
||||
// based groups. Both variants support complex objects if needed.
|
||||
var claimGroups []interface{}
|
||||
switch raw := rawClaim.(type) {
|
||||
case []interface{}:
|
||||
claimGroups = raw
|
||||
case interface{}:
|
||||
claimGroups = []interface{}{raw}
|
||||
}
|
||||
|
||||
groups := []string{}
|
||||
for _, rawGroup := range claimGroups {
|
||||
formattedGroup, err := formatGroup(rawGroup)
|
||||
if err != nil {
|
||||
logger.Errorf("Warning: unable to format group of type %s with error %s",
|
||||
reflect.TypeOf(rawGroup), err)
|
||||
continue
|
||||
}
|
||||
groups = append(groups, formattedGroup)
|
||||
}
|
||||
return groups
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,437 @@
|
|||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
. "github.com/onsi/gomega"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
const (
|
||||
idToken = "eyJfoobar123.eyJbaz987.IDToken"
|
||||
accessToken = "eyJfoobar123.eyJbaz987.AccessToken"
|
||||
refreshToken = "eyJfoobar123.eyJbaz987.RefreshToken"
|
||||
|
||||
oidcIssuer = "https://issuer.example.com"
|
||||
oidcClientID = "https://test.myapp.com"
|
||||
oidcSecret = "SuperSecret123456789"
|
||||
|
||||
failureTokenID = "this-id-fails-verification"
|
||||
)
|
||||
|
||||
var (
|
||||
verified = true
|
||||
unverified = false
|
||||
|
||||
standardClaims = jwt.StandardClaims{
|
||||
Audience: oidcClientID,
|
||||
ExpiresAt: time.Now().Add(time.Duration(5) * time.Minute).Unix(),
|
||||
Id: "id-some-id",
|
||||
IssuedAt: time.Now().Unix(),
|
||||
Issuer: oidcIssuer,
|
||||
NotBefore: 0,
|
||||
Subject: "123456789",
|
||||
}
|
||||
|
||||
defaultIDToken = idTokenClaims{
|
||||
Name: "Jane Dobbs",
|
||||
Email: "janed@me.com",
|
||||
Phone: "+4798765432",
|
||||
Picture: "http://mugbook.com/janed/me.jpg",
|
||||
Groups: []string{"test:a", "test:b"},
|
||||
Roles: []string{"test:c", "test:d"},
|
||||
Verified: &verified,
|
||||
StandardClaims: standardClaims,
|
||||
}
|
||||
|
||||
complexGroupsIDToken = idTokenClaims{
|
||||
Name: "Complex Claim",
|
||||
Email: "complex@claims.com",
|
||||
Phone: "+5439871234",
|
||||
Picture: "http://mugbook.com/complex/claims.jpg",
|
||||
Groups: []map[string]interface{}{
|
||||
{
|
||||
"groupId": "Admin Group Id",
|
||||
"roles": []string{"Admin"},
|
||||
},
|
||||
},
|
||||
Roles: []string{"test:simple", "test:roles"},
|
||||
Verified: &verified,
|
||||
StandardClaims: standardClaims,
|
||||
}
|
||||
|
||||
unverifiedIDToken = idTokenClaims{
|
||||
Name: "Mystery Man",
|
||||
Email: "unverified@email.com",
|
||||
Phone: "+4025205729",
|
||||
Picture: "http://mugbook.com/unverified/email.jpg",
|
||||
Groups: []string{"test:a", "test:b"},
|
||||
Roles: []string{"test:c", "test:d"},
|
||||
Verified: &unverified,
|
||||
StandardClaims: standardClaims,
|
||||
}
|
||||
|
||||
minimalIDToken = idTokenClaims{
|
||||
StandardClaims: standardClaims,
|
||||
}
|
||||
)
|
||||
|
||||
type idTokenClaims struct {
|
||||
Name string `json:"preferred_username,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
Phone string `json:"phone_number,omitempty"`
|
||||
Picture string `json:"picture,omitempty"`
|
||||
Groups interface{} `json:"groups,omitempty"`
|
||||
Roles interface{} `json:"roles,omitempty"`
|
||||
Verified *bool `json:"email_verified,omitempty"`
|
||||
jwt.StandardClaims
|
||||
}
|
||||
|
||||
type mockJWKS struct{}
|
||||
|
||||
func (mockJWKS) VerifySignature(_ context.Context, jwt string) ([]byte, error) {
|
||||
decoded, err := base64.RawURLEncoding.DecodeString(strings.Split(jwt, ".")[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tokenClaims := &idTokenClaims{}
|
||||
err = json.Unmarshal(decoded, tokenClaims)
|
||||
if err != nil || tokenClaims.Id == failureTokenID {
|
||||
return nil, fmt.Errorf("the validation failed for subject [%v]", tokenClaims.Subject)
|
||||
}
|
||||
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
func newSignedTestIDToken(tokenClaims idTokenClaims) (string, error) {
|
||||
key, _ := rsa.GenerateKey(rand.Reader, 2048)
|
||||
standardClaims := jwt.NewWithClaims(jwt.SigningMethodRS256, tokenClaims)
|
||||
return standardClaims.SignedString(key)
|
||||
}
|
||||
|
||||
func newTestOauth2Token() *oauth2.Token {
|
||||
return &oauth2.Token{
|
||||
AccessToken: accessToken,
|
||||
TokenType: "Bearer",
|
||||
RefreshToken: refreshToken,
|
||||
Expiry: time.Time{}.Add(time.Duration(5) * time.Second),
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderData_verifyIDToken(t *testing.T) {
|
||||
failureIDToken := defaultIDToken
|
||||
failureIDToken.Id = failureTokenID
|
||||
|
||||
testCases := map[string]struct {
|
||||
IDToken *idTokenClaims
|
||||
Verifier bool
|
||||
ExpectIDToken bool
|
||||
ExpectedError error
|
||||
}{
|
||||
"Valid ID Token": {
|
||||
IDToken: &defaultIDToken,
|
||||
Verifier: true,
|
||||
ExpectIDToken: true,
|
||||
ExpectedError: nil,
|
||||
},
|
||||
"Invalid ID Token": {
|
||||
IDToken: &failureIDToken,
|
||||
Verifier: true,
|
||||
ExpectIDToken: false,
|
||||
ExpectedError: errors.New("failed to verify signature: the validation failed for subject [123456789]"),
|
||||
},
|
||||
"Missing ID Token": {
|
||||
IDToken: nil,
|
||||
Verifier: true,
|
||||
ExpectIDToken: false,
|
||||
ExpectedError: ErrMissingIDToken,
|
||||
},
|
||||
"OIDC Verifier not Configured": {
|
||||
IDToken: &defaultIDToken,
|
||||
Verifier: false,
|
||||
ExpectIDToken: false,
|
||||
ExpectedError: ErrMissingOIDCVerifier,
|
||||
},
|
||||
}
|
||||
|
||||
for testName, tc := range testCases {
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
|
||||
token := newTestOauth2Token()
|
||||
if tc.IDToken != nil {
|
||||
idToken, err := newSignedTestIDToken(*tc.IDToken)
|
||||
g.Expect(err).ToNot(HaveOccurred())
|
||||
token = token.WithExtra(map[string]interface{}{
|
||||
"id_token": idToken,
|
||||
})
|
||||
}
|
||||
|
||||
provider := &ProviderData{}
|
||||
if tc.Verifier {
|
||||
provider.Verifier = oidc.NewVerifier(
|
||||
oidcIssuer,
|
||||
mockJWKS{},
|
||||
&oidc.Config{ClientID: oidcClientID},
|
||||
)
|
||||
}
|
||||
verified, err := provider.verifyIDToken(context.Background(), token)
|
||||
if err != nil {
|
||||
g.Expect(err).To(Equal(tc.ExpectedError))
|
||||
}
|
||||
|
||||
if tc.ExpectIDToken {
|
||||
g.Expect(verified).ToNot(BeNil())
|
||||
g.Expect(*verified).To(BeAssignableToTypeOf(oidc.IDToken{}))
|
||||
} else {
|
||||
g.Expect(verified).To(BeNil())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderData_buildSessionFromClaims(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
IDToken idTokenClaims
|
||||
AllowUnverified bool
|
||||
EmailClaim string
|
||||
GroupsClaim string
|
||||
ExpectedError error
|
||||
ExpectedSession *sessions.SessionState
|
||||
}{
|
||||
"Standard": {
|
||||
IDToken: defaultIDToken,
|
||||
AllowUnverified: false,
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "123456789",
|
||||
Email: "janed@me.com",
|
||||
Groups: []string{"test:a", "test:b"},
|
||||
PreferredUsername: "Jane Dobbs",
|
||||
},
|
||||
},
|
||||
"Unverified Denied": {
|
||||
IDToken: unverifiedIDToken,
|
||||
AllowUnverified: false,
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ExpectedError: errors.New("email in id_token (unverified@email.com) isn't verified"),
|
||||
},
|
||||
"Unverified Allowed": {
|
||||
IDToken: unverifiedIDToken,
|
||||
AllowUnverified: true,
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "123456789",
|
||||
Email: "unverified@email.com",
|
||||
Groups: []string{"test:a", "test:b"},
|
||||
PreferredUsername: "Mystery Man",
|
||||
},
|
||||
},
|
||||
"Complex Groups": {
|
||||
IDToken: complexGroupsIDToken,
|
||||
AllowUnverified: true,
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "groups",
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "123456789",
|
||||
Email: "complex@claims.com",
|
||||
Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"},
|
||||
PreferredUsername: "Complex Claim",
|
||||
},
|
||||
},
|
||||
"Email Claim Switched": {
|
||||
IDToken: unverifiedIDToken,
|
||||
AllowUnverified: true,
|
||||
EmailClaim: "phone_number",
|
||||
GroupsClaim: "groups",
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "123456789",
|
||||
Email: "+4025205729",
|
||||
Groups: []string{"test:a", "test:b"},
|
||||
PreferredUsername: "Mystery Man",
|
||||
},
|
||||
},
|
||||
"Email Claim Switched to Non String": {
|
||||
IDToken: unverifiedIDToken,
|
||||
AllowUnverified: true,
|
||||
EmailClaim: "roles",
|
||||
GroupsClaim: "groups",
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "123456789",
|
||||
Email: "[test:c test:d]",
|
||||
Groups: []string{"test:a", "test:b"},
|
||||
PreferredUsername: "Mystery Man",
|
||||
},
|
||||
},
|
||||
"Email Claim Non Existent": {
|
||||
IDToken: unverifiedIDToken,
|
||||
AllowUnverified: true,
|
||||
EmailClaim: "aksjdfhjksadh",
|
||||
GroupsClaim: "groups",
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "123456789",
|
||||
Email: "",
|
||||
Groups: []string{"test:a", "test:b"},
|
||||
PreferredUsername: "Mystery Man",
|
||||
},
|
||||
},
|
||||
"Groups Claim Switched": {
|
||||
IDToken: defaultIDToken,
|
||||
AllowUnverified: false,
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "roles",
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "123456789",
|
||||
Email: "janed@me.com",
|
||||
Groups: []string{"test:c", "test:d"},
|
||||
PreferredUsername: "Jane Dobbs",
|
||||
},
|
||||
},
|
||||
"Groups Claim Non Existent": {
|
||||
IDToken: defaultIDToken,
|
||||
AllowUnverified: false,
|
||||
EmailClaim: "email",
|
||||
GroupsClaim: "alskdjfsalkdjf",
|
||||
ExpectedSession: &sessions.SessionState{
|
||||
User: "123456789",
|
||||
Email: "janed@me.com",
|
||||
Groups: nil,
|
||||
PreferredUsername: "Jane Dobbs",
|
||||
},
|
||||
},
|
||||
}
|
||||
for testName, tc := range testCases {
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
|
||||
provider := &ProviderData{
|
||||
Verifier: oidc.NewVerifier(
|
||||
oidcIssuer,
|
||||
mockJWKS{},
|
||||
&oidc.Config{ClientID: oidcClientID},
|
||||
),
|
||||
}
|
||||
provider.AllowUnverifiedEmail = tc.AllowUnverified
|
||||
provider.EmailClaim = tc.EmailClaim
|
||||
provider.GroupsClaim = tc.GroupsClaim
|
||||
|
||||
rawIDToken, err := newSignedTestIDToken(tc.IDToken)
|
||||
g.Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
idToken, err := provider.Verifier.Verify(context.Background(), rawIDToken)
|
||||
g.Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
ss, err := provider.buildSessionFromClaims(idToken)
|
||||
if err != nil {
|
||||
g.Expect(err).To(Equal(tc.ExpectedError))
|
||||
}
|
||||
if ss != nil {
|
||||
g.Expect(ss).To(Equal(tc.ExpectedSession))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderData_extractGroups(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
Claims map[string]interface{}
|
||||
GroupsClaim string
|
||||
ExpectedGroups []string
|
||||
}{
|
||||
"Standard String Groups": {
|
||||
Claims: map[string]interface{}{
|
||||
"email": "this@does.not.matter.com",
|
||||
"groups": []interface{}{"three", "string", "groups"},
|
||||
},
|
||||
GroupsClaim: "groups",
|
||||
ExpectedGroups: []string{"three", "string", "groups"},
|
||||
},
|
||||
"Different Claim Name": {
|
||||
Claims: map[string]interface{}{
|
||||
"email": "this@does.not.matter.com",
|
||||
"roles": []interface{}{"three", "string", "roles"},
|
||||
},
|
||||
GroupsClaim: "roles",
|
||||
ExpectedGroups: []string{"three", "string", "roles"},
|
||||
},
|
||||
"Numeric Groups": {
|
||||
Claims: map[string]interface{}{
|
||||
"email": "this@does.not.matter.com",
|
||||
"groups": []interface{}{1, 2, 3},
|
||||
},
|
||||
GroupsClaim: "groups",
|
||||
ExpectedGroups: []string{"1", "2", "3"},
|
||||
},
|
||||
"Complex Groups": {
|
||||
Claims: map[string]interface{}{
|
||||
"email": "this@does.not.matter.com",
|
||||
"groups": []interface{}{
|
||||
map[string]interface{}{
|
||||
"groupId": "Admin Group Id",
|
||||
"roles": []string{"Admin"},
|
||||
},
|
||||
12345,
|
||||
"Just::A::String",
|
||||
},
|
||||
},
|
||||
GroupsClaim: "groups",
|
||||
ExpectedGroups: []string{
|
||||
"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}",
|
||||
"12345",
|
||||
"Just::A::String",
|
||||
},
|
||||
},
|
||||
"Missing Groups Claim Returns Nil": {
|
||||
Claims: map[string]interface{}{
|
||||
"email": "this@does.not.matter.com",
|
||||
},
|
||||
GroupsClaim: "groups",
|
||||
ExpectedGroups: nil,
|
||||
},
|
||||
"Non List Groups": {
|
||||
Claims: map[string]interface{}{
|
||||
"email": "this@does.not.matter.com",
|
||||
"groups": "singleton",
|
||||
},
|
||||
GroupsClaim: "groups",
|
||||
ExpectedGroups: []string{"singleton"},
|
||||
},
|
||||
}
|
||||
for testName, tc := range testCases {
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
|
||||
provider := &ProviderData{
|
||||
Verifier: oidc.NewVerifier(
|
||||
oidcIssuer,
|
||||
mockJWKS{},
|
||||
&oidc.Config{ClientID: oidcClientID},
|
||||
),
|
||||
}
|
||||
provider.GroupsClaim = tc.GroupsClaim
|
||||
|
||||
groups := provider.extractGroups(tc.Claims)
|
||||
if tc.ExpectedGroups != nil {
|
||||
g.Expect(groups).To(Equal(tc.ExpectedGroups))
|
||||
} else {
|
||||
g.Expect(groups).To(BeNil())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -22,6 +22,14 @@ var (
|
|||
// code
|
||||
ErrMissingCode = errors.New("missing code")
|
||||
|
||||
// ErrMissingIDToken is returned when an oidc.Token does not contain the
|
||||
// extra `id_token` field for an IDToken.
|
||||
ErrMissingIDToken = errors.New("missing id_token")
|
||||
|
||||
// ErrMissingOIDCVerifier is returned when a provider didn't set `Verifier`
|
||||
// but an attempt to call `Verifier.Verify` was about to be made.
|
||||
ErrMissingOIDCVerifier = errors.New("oidc verifier is not configured")
|
||||
|
||||
_ Provider = (*ProviderData)(nil)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,13 @@
|
|||
package providers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/bitly/go-simplejson"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
@ -55,3 +59,42 @@ func makeLoginURL(p *ProviderData, redirectURI, state string, extraParams url.Va
|
|||
a.RawQuery = params.Encode()
|
||||
return a
|
||||
}
|
||||
|
||||
// getIDToken extracts an IDToken stored in the `Extra` fields of an
|
||||
// oauth2.Token
|
||||
func getIDToken(token *oauth2.Token) string {
|
||||
idToken, ok := token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return idToken
|
||||
}
|
||||
|
||||
// formatGroup coerces an OIDC groups claim into a string
|
||||
// If it is non-string, marshal it into JSON.
|
||||
func formatGroup(rawGroup interface{}) (string, error) {
|
||||
if group, ok := rawGroup.(string); ok {
|
||||
return group, nil
|
||||
}
|
||||
|
||||
jsonGroup, err := json.Marshal(rawGroup)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(jsonGroup), nil
|
||||
}
|
||||
|
||||
// coerceArray extracts a field from simplejson.Json that might be a
|
||||
// singleton or a list and coerces it into a list.
|
||||
func coerceArray(sj *simplejson.Json, key string) []interface{} {
|
||||
array, err := sj.Get(key).Array()
|
||||
if err == nil {
|
||||
return array
|
||||
}
|
||||
|
||||
single := sj.Get(key).Interface()
|
||||
if single == nil {
|
||||
return nil
|
||||
}
|
||||
return []interface{}{single}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,9 +5,10 @@ import (
|
|||
"testing"
|
||||
|
||||
. "github.com/onsi/gomega"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func TestMakeAuhtorizationHeader(t *testing.T) {
|
||||
func Test_makeAuthorizationHeader(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
prefix string
|
||||
|
|
@ -64,3 +65,49 @@ func TestMakeAuhtorizationHeader(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_getIDToken(t *testing.T) {
|
||||
const idToken = "eyJfoobar.eyJfoobar.12345asdf"
|
||||
g := NewWithT(t)
|
||||
|
||||
token := &oauth2.Token{}
|
||||
g.Expect(getIDToken(token)).To(Equal(""))
|
||||
|
||||
extraToken := token.WithExtra(map[string]interface{}{
|
||||
"id_token": idToken,
|
||||
})
|
||||
g.Expect(getIDToken(extraToken)).To(Equal(idToken))
|
||||
}
|
||||
|
||||
func Test_formatGroup(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
rawGroup interface{}
|
||||
expected string
|
||||
}{
|
||||
"String Group": {
|
||||
rawGroup: "group",
|
||||
expected: "group",
|
||||
},
|
||||
"Numeric Group": {
|
||||
rawGroup: 123,
|
||||
expected: "123",
|
||||
},
|
||||
"Map Group": {
|
||||
rawGroup: map[string]string{"id": "1", "name": "Test"},
|
||||
expected: "{\"id\":\"1\",\"name\":\"Test\"}",
|
||||
},
|
||||
"List Group": {
|
||||
rawGroup: []string{"First", "Second"},
|
||||
expected: "[\"First\",\"Second\"]",
|
||||
},
|
||||
}
|
||||
|
||||
for testName, tc := range testCases {
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
formattedGroup, err := formatGroup(tc.rawGroup)
|
||||
g.Expect(err).ToNot(HaveOccurred())
|
||||
g.Expect(formattedGroup).To(Equal(tc.expected))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue