Merge remote-tracking branch 'origin2/master'

This commit is contained in:
Kevin Kreitner 2021-01-19 16:23:46 +01:00
commit e7919f0535
47 changed files with 2704 additions and 1044 deletions

View File

@ -4,7 +4,21 @@
## Important Notes
- [#964](https://github.com/oauth2-proxy/oauth2-proxy/pull/964) Redirect URL generation will attempt secondary strategies
in the priority chain if any fail the `IsValidRedirect` security check. Previously any failures fell back to `/`.
- [#953](https://github.com/oauth2-proxy/oauth2-proxy/pull/953) Keycloak will now use `--profile-url` if set for the userinfo endpoint
instead of `--validate-url`. `--validate-url` will still work for backwards compatibility.
- [#957](https://github.com/oauth2-proxy/oauth2-proxy/pull/957) To use X-Forwarded-{Proto,Host,Uri} on redirect detection, `--reverse-proxy` must be `true`.
- [#936](https://github.com/oauth2-proxy/oauth2-proxy/pull/936) `--user-id-claim` option is deprecated and replaced by `--oidc-email-claim`
- [#630](https://github.com/oauth2-proxy/oauth2-proxy/pull/630) Gitlab projects needs a Gitlab application with the extra `read_api` enabled
- [#849](https://github.com/oauth2-proxy/oauth2-proxy/pull/849) `/oauth2/auth` `allowed_groups` querystring parameter can be paired with the `allowed-groups` configuration option.
- The `allowed_groups` querystring parameter can specify multiple comma delimited groups.
- In this scenario, the user must have a group (from their multiple groups) present in both lists to not get a 401 or 403 response code.
- Example:
- OAuth2-Proxy globally sets the `allowed_groups` as `engineering`.
- An application using Kubernetes ingress uses the `/oauth2/auth` endpoint with `allowed_groups` querystring set to `backend`.
- A user must have a session with the groups `["engineering", "backend"]` to pass authorization.
- Another user with the groups `["engineering", "frontend"]` would fail the querystring authorization portion.
- [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Existing sessions from v6.0.0 or earlier are no longer valid. They will trigger a reauthentication.
- [#826](https://github.com/oauth2-proxy/oauth2-proxy/pull/826) `skip-auth-strip-headers` now applies to all requests, not just those where authentication would be skipped.
- [#797](https://github.com/oauth2-proxy/oauth2-proxy/pull/797) The behavior of the Google provider Groups restriction changes with this
@ -18,11 +32,19 @@
- [#575](https://github.com/oauth2-proxy/oauth2-proxy/pull/575) Sessions from v5.1.1 or earlier will no longer validate since they were not signed with SHA1.
- Sessions from v6.0.0 or later had a graceful conversion to SHA256 that resulted in no reauthentication
- Upgrading from v5.1.1 or earlier will result in a reauthentication
- [#616](https://github.com/oauth2-proxy/oauth2-proxy/pull/616) Ensure you have configured oauth2-proxy to use the `groups` scope. The user may be logged out initially as they may not currently have the `groups` claim however after going back through login process wil be authenticated.
- [#616](https://github.com/oauth2-proxy/oauth2-proxy/pull/616) Ensure you have configured oauth2-proxy to use the `groups` scope.
- The user may be logged out initially as they may not currently have the `groups` claim however after going back through login process wil be authenticated.
- [#839](https://github.com/oauth2-proxy/oauth2-proxy/pull/839) Enables complex data structures for group claim entries, which are output as Json by default.
## Breaking Changes
- [#964](https://github.com/oauth2-proxy/oauth2-proxy/pull/964) `--reverse-proxy` must be true to trust `X-Forwarded-*` headers as canonical.
These are used throughout the application in redirect URLs, cookie domains and host logging logic. These are the headers:
- `X-Forwarded-Proto` instead of `req.URL.Scheme`
- `X-Forwarded-Host` instead of `req.Host`
- `X-Forwarded-Uri` instead of `req.URL.RequestURI()`
- [#953](https://github.com/oauth2-proxy/oauth2-proxy/pull/953) In config files & envvar configs, `keycloak_group` is now the plural `keycloak_groups`.
Flag configs are still `--keycloak-group` but it can be passed multiple times.
- [#911](https://github.com/oauth2-proxy/oauth2-proxy/pull/911) Specifying a non-existent provider will cause OAuth2-Proxy to fail on startup instead of defaulting to "google".
- [#797](https://github.com/oauth2-proxy/oauth2-proxy/pull/797) Security changes to Google provider group authorization flow
- If you change the list of allowed groups, existing sessions that now don't have a valid group will be logged out immediately.
@ -44,15 +66,23 @@
## Changes since v6.1.1
- [#995](https://github.com/oauth2-proxy/oauth2-proxy/pull/995) Add Security Policy (@JoelSpeed)
- [#964](https://github.com/oauth2-proxy/oauth2-proxy/pull/964) Require `--reverse-proxy` true to trust `X-Forwareded-*` type headers (@NickMeves)
- [#970](https://github.com/oauth2-proxy/oauth2-proxy/pull/970) Fix joined cookie name for those containing underline in the suffix (@peppered)
- [#953](https://github.com/oauth2-proxy/oauth2-proxy/pull/953) Migrate Keycloak to EnrichSession & support multiple groups for authorization (@NickMeves)
- [#957](https://github.com/oauth2-proxy/oauth2-proxy/pull/957) Use X-Forwarded-{Proto,Host,Uri} on redirect as last resort (@linuxgemini)
- [#630](https://github.com/oauth2-proxy/oauth2-proxy/pull/630) Add support for Gitlab project based authentication (@factorysh)
- [#907](https://github.com/oauth2-proxy/oauth2-proxy/pull/907) Introduce alpha configuration option to enable testing of structured configuration (@JoelSpeed)
- [#938](https://github.com/oauth2-proxy/oauth2-proxy/pull/938) Cleanup missed provider renaming refactor methods (@NickMeves)
- [#816](https://github.com/oauth2-proxy/oauth2-proxy/pull/816) (via [#936](https://github.com/oauth2-proxy/oauth2-proxy/pull/936)) Support non-list group claims (@loafoe)
- [#936](https://github.com/oauth2-proxy/oauth2-proxy/pull/936) Refactor OIDC Provider and support groups from Profile URL (@NickMeves)
- [#869](https://github.com/oauth2-proxy/oauth2-proxy/pull/869) Streamline provider interface method names and signatures (@NickMeves)
- [#849](https://github.com/oauth2-proxy/oauth2-proxy/pull/849) Support group authorization on `oauth2/auth` endpoint via `allowed_groups` querystring (@NickMeves)
- [#925](https://github.com/oauth2-proxy/oauth2-proxy/pull/925) Fix basic auth legacy header conversion (@JoelSpeed)
- [#916](https://github.com/oauth2-proxy/oauth2-proxy/pull/916) Add AlphaOptions struct to prepare for alpha config loading (@JoelSpeed)
- [#923](https://github.com/oauth2-proxy/oauth2-proxy/pull/923) Support TLS 1.3 (@aajisaka)
- [#918](https://github.com/oauth2-proxy/oauth2-proxy/pull/918) Fix log header output (@JoelSpeed)
- [#911](https://github.com/oauth2-proxy/oauth2-proxy/pull/911) Validate provider type on startup.
- [#869](https://github.com/oauth2-proxy/oauth2-proxy/pull/869) Streamline provider interface method names and signatures (@NickMeves)
- [#906](https://github.com/oauth2-proxy/oauth2-proxy/pull/906) Set up v6.1.x versioned documentation as default documentation (@JoelSpeed)
- [#905](https://github.com/oauth2-proxy/oauth2-proxy/pull/905) Remove v5 legacy sessions support (@NickMeves)
- [#904](https://github.com/oauth2-proxy/oauth2-proxy/pull/904) Set `skip-auth-strip-headers` to `true` by default (@NickMeves)
@ -79,6 +109,7 @@
- [#750](https://github.com/oauth2-proxy/oauth2-proxy/pull/750) ci: Migrate to Github Actions (@shinebayar-g)
- [#829](https://github.com/oauth2-proxy/oauth2-proxy/pull/820) Rename test directory to testdata (@johejo)
- [#819](https://github.com/oauth2-proxy/oauth2-proxy/pull/819) Improve CI (@johejo)
- [#989](https://github.com/oauth2-proxy/oauth2-proxy/pull/989) Adapt isAjax to support mimetype lists (@rassie)
# v6.1.1

View File

@ -1,2 +1,3 @@
Joel Speed <joel.speed@hotmail.co.uk> (@JoelSpeed)
Henry Jenkins <henry@henryjenkins.name> (@steakunderscore)
Nick Meves <nick.meves@greenhouse.io> (@NickMeves)

3
SECURITY.md Normal file
View File

@ -0,0 +1,3 @@
# Security Disclosures
Please see [our community docs](https://oauth2-proxy.github.io/oauth2-proxy/docs/community/security) for our security policy.

View File

@ -0,0 +1,49 @@
---
id: security
title: Security
---
:::note
OAuth2 Proxy is a community project.
Maintainers do not work on this project full time, and as such,
while we endeavour to respond to disclosures as quickly as possible,
this may take longer than in projects with corporate sponsorship.
:::
## Security Disclosures
:::important
If you believe you have found a vulnerability within OAuth2 Proxy or any of its
dependencies, please do NOT open an issue or PR on GitHub, please do NOT post
any details publicly.
:::
Security disclosures MUST be done in private.
If you have found an issue that you would like to bring to the attention of the
maintenance team for OAuth2 Proxy, please compose an email and send it to the
list of maintainers in our [MAINTAINERS](https://github.com/oauth2-proxy/oauth2-proxy/blob/master/MAINTAINERS) file.
Please include as much detail as possible.
Ideally, your disclosure should include:
- A reproducible case that can be used to demonstrate the exploit
- How you discovered this vulnerability
- A potential fix for the issue (if you have thought of one)
- Versions affected (if not present in master)
- Your GitHub ID
### How will we respond to disclosures?
We use [GitHub Security Advisories](https://docs.github.com/en/github/managing-security-vulnerabilities/about-github-security-advisories)
to privately discuss fixes for disclosed vulnerabilities.
If you include a GitHub ID with your disclosure we will add you as a collaborator
for the advisory so that you can join the discussion and validate any fixes
we may propose.
For minor issues and previously disclosed vulnerabilities (typically for
dependencies), we may use regular PRs for fixes and forego the security advisory.
Once a fix has been agreed upon, we will merge the fix and create a new release.
If we have multiple security issues in flight simultaneously, we may delay
merging fixes until all patches are ready.
We may also backport the fix to previous releases,
but this will be at the discretion of the maintainers.

View File

@ -135,15 +135,25 @@ If you are using GitHub enterprise, make sure you set the following to the appro
Make sure you set the following to the appropriate url:
-provider=keycloak
-client-id=<client you have created>
-client-secret=<your client's secret>
-login-url="http(s)://<keycloak host>/auth/realms/<your realm>/protocol/openid-connect/auth"
-redeem-url="http(s)://<keycloak host>/auth/realms/<your realm>/protocol/openid-connect/token"
-validate-url="http(s)://<keycloak host>/auth/realms/<your realm>/protocol/openid-connect/userinfo"
-keycloak-group=<user_group>
--provider=keycloak
--client-id=<client you have created>
--client-secret=<your client's secret>
--login-url="http(s)://<keycloak host>/auth/realms/<your realm>/protocol/openid-connect/auth"
--redeem-url="http(s)://<keycloak host>/auth/realms/<your realm>/protocol/openid-connect/token"
--profile-url="http(s)://<keycloak host>/auth/realms/<your realm>/protocol/openid-connect/userinfo"
--validate-url="http(s)://<keycloak host>/auth/realms/<your realm>/protocol/openid-connect/userinfo"
--keycloak-group=<first_allowed_user_group>
--keycloak-group=<second_allowed_user_group>
The group management in keycloak is using a tree. If you create a group named admin in keycloak you should define the 'keycloak-group' value to /admin.
For group based authorization, the optional `--keycloak-group` (legacy) or `--allowed-group` (global standard)
flags can be used to specify which groups to limit access to.
If these are unset but a `groups` mapper is set up above in step (3), the provider will still
populate the `X-Forwarded-Groups` header to your upstream server with the `groups` data in the
Keycloak userinfo endpoint response.
The group management in keycloak is using a tree. If you create a group named admin in keycloak
you should define the 'keycloak-group' value to /admin.
### GitLab Auth Provider

View File

@ -74,7 +74,8 @@ An example [oauth2-proxy.cfg](https://github.com/oauth2-proxy/oauth2-proxy/blob/
| `--insecure-oidc-skip-issuer-verification` | bool | allow the OIDC issuer URL to differ from the expected (currently required for Azure multi-tenant compatibility) | false |
| `--oidc-issuer-url` | string | the OpenID Connect issuer URL, e.g. `"https://accounts.google.com"` | |
| `--oidc-jwks-url` | string | OIDC JWKS URI for token verification; required if OIDC discovery is disabled | |
| `--oidc-groups-claim` | string | which claim contains the user groups | `"groups"` |
| `--oidc-email-claim` | string | which OIDC claim contains the user's email | `"email"` |
| `--oidc-groups-claim` | string | which OIDC claim contains the user groups | `"groups"` |
| `--pass-access-token` | bool | pass OAuth access_token to upstream via X-Forwarded-Access-Token header. When used with `--set-xauthrequest` this adds the X-Auth-Request-Access-Token header to the response | false |
| `--pass-authorization-header` | bool | pass OIDC IDToken to upstream via Authorization Bearer header | false |
| `--pass-basic-auth` | bool | pass HTTP Basic Auth, X-Forwarded-User, X-Forwarded-Email and X-Forwarded-Preferred-Username information to upstream | true |
@ -105,7 +106,7 @@ An example [oauth2-proxy.cfg](https://github.com/oauth2-proxy/oauth2-proxy/blob/
| `--request-logging` | bool | Log requests | true |
| `--request-logging-format` | string | Template for request log lines | see [Logging Configuration](#logging-configuration) |
| `--resource` | string | The resource that is protected (Azure AD only) | |
| `--reverse-proxy` | bool | are we running behind a reverse proxy, controls whether headers like X-Real-IP are accepted | false |
| `--reverse-proxy` | bool | are we running behind a reverse proxy, controls whether headers like X-Real-IP are accepted and allows X-Forwarded-{Proto,Host,Uri} headers to be used on redirect selection | false |
| `--scope` | string | OAuth scope specification | |
| `--session-cookie-minimal` | bool | strip OAuth tokens from cookie session stores if they aren't needed (cookie session store only) | false |
| `--session-store-type` | string | [Session data storage backend](sessions.md); redis or cookie | cookie |
@ -128,7 +129,6 @@ An example [oauth2-proxy.cfg](https://github.com/oauth2-proxy/oauth2-proxy/blob/
| `--tls-cert-file` | string | path to certificate file | |
| `--tls-key-file` | string | path to private key file | |
| `--upstream` | string \| list | the http url(s) of the upstream endpoint, file:// paths for static files or `static://<status_code>` for static response. Routing is based on the path | |
| `--user-id-claim` | string | which claim contains the user ID | \["email"\] |
| `--allowed-group` | string \| list | restrict logins to members of this group (may be given multiple times) | |
| `--validate-url` | string | Access token validation endpoint | |
| `--version` | n/a | print version string | |
@ -354,6 +354,73 @@ It is recommended to use `--session-store-type=redis` when expecting large sessi
You have to substitute *name* with the actual cookie name you configured via --cookie-name parameter. If you don't set a custom cookie name the variable should be "$upstream_cookie__oauth2_proxy_1" instead of "$upstream_cookie_name_1" and the new cookie-name should be "_oauth2_proxy_1=" instead of "name_1=".
## Configuring for use with the Traefik (v2) `ForwardAuth` middleware
**This option requires `--reverse-proxy` option to be set.**
The [Traefik v2 `ForwardAuth` middleware](https://doc.traefik.io/traefik/middlewares/forwardauth/) allows Traefik to authenticate requests via the oauth2-proxy's `/oauth2/auth` endpoint on every request, which only returns a 202 Accepted response or a 401 Unauthorized response without proxying the whole request through. For example, on Dynamic File (YAML) Configuration:
```yaml
http:
routers:
a-service:
rule: "Host(`a-service.example.com`)"
service: a-service-backend
middlewares:
- oauth-errors
- oauth-auth
tls:
certResolver: default
domains:
- main: "example.com"
sans:
- "*.example.com"
oauth:
rule: "Host(`a-service.example.com`, `oauth.example.com`) && PathPrefix(`/oauth2/`)"
middlewares:
- auth-headers
service: oauth-backend
tls:
certResolver: default
domains:
- main: "example.com"
sans:
- "*.example.com"
services:
a-service-backend:
loadBalancer:
servers:
- url: http://172.16.0.2:7555
oauth-backend:
loadBalancer:
servers:
- url: http://172.16.0.1:4180
middlewares:
auth-headers:
headers:
sslRedirect: true
stsSeconds: 315360000
browserXssFilter: true
contentTypeNosniff: true
forceSTSHeader: true
sslHost: example.com
stsIncludeSubdomains: true
stsPreload: true
frameDeny: true
oauth-auth:
forwardAuth:
address: https://oauth.example.com/oauth2/auth
trustForwardHeader: true
oauth-errors:
errors:
status:
- "401-403"
service: oauth-backend
query: "/oauth2/sign_in"
```
:::note
If you set up your OAuth2 provider to rotate your client secret, you can use the `client-secret-file` option to reload the secret when it is updated.
:::

View File

@ -20,5 +20,11 @@ module.exports = {
collapsed: false,
items: ['features/endpoints', 'features/request_signatures'],
},
{
type: 'category',
label: 'Community',
collapsed: false,
items: ['community/security'],
},
],
};

View File

@ -0,0 +1,49 @@
---
id: security
title: Security
---
:::note
OAuth2 Proxy is a community project.
Maintainers do not work on this project full time, and as such,
while we endeavour to respond to disclosures as quickly as possible,
this may take longer than in projects with corporate sponsorship.
:::
## Security Disclosures
:::important
If you believe you have found a vulnerability within OAuth2 Proxy or any of its
dependencies, please do NOT open an issue or PR on GitHub, please do NOT post any
details publicly.
:::
Security disclosures MUST be done in private.
If you have found an issue that you would like to bring to the attention of the
maintenance team for OAuth2 Proxy, please compose an email and send it to the
list of maintainers in our [MAINTAINERS](https://github.com/oauth2-proxy/oauth2-proxy/blob/master/MAINTAINERS) file.
Please include as much detail as possible.
Ideally, your disclosure should include:
- A reproducible case that can be used to demonstrate the exploit
- How you discovered this vulnerability
- A potential fix for the issue (if you have thought of one)
- Versions affected (if not present in master)
- Your GitHub ID
### How will we respond to disclosures?
We use [GitHub Security Advisories](https://docs.github.com/en/github/managing-security-vulnerabilities/about-github-security-advisories)
to privately discuss fixes for disclosed vulnerabilities.
If you include a GitHub ID with your disclosure we will add you as a collaborator
for the advisory so that you can join the discussion and validate any fixes
we may propose.
For minor issues and previously disclosed vulnerabilities (typically for
dependencies), we may use regular PRs for fixes and forego the security advisory.
Once a fix has been agreed upon, we will merge the fix and create a new release.
If we have multiple security issues in flight simultaneously, we may delay
merging fixes until all patches are ready.
We may also backport the fix to previous releases,
but this will be at the discretion of the maintainers.

View File

@ -45,6 +45,17 @@
"id": "version-6.1.x/features/request_signatures"
}
]
},
{
"collapsed": false,
"type": "category",
"label": "Community",
"items": [
{
"type": "doc",
"id": "version-6.1.x/community/security"
}
]
}
]
}

View File

@ -24,16 +24,14 @@ import (
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/upstream"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util"
"github.com/oauth2-proxy/oauth2-proxy/v7/providers"
)
const (
httpScheme = "http"
httpsScheme = "https"
schemeHTTPS = "https"
applicationJSON = "application/json"
)
@ -229,7 +227,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
// the OAuth2 Proxy authentication logic kicks in.
// For example forcing HTTPS or health checks.
func buildPreAuthChain(opts *options.Options) (alice.Chain, error) {
chain := alice.New(middleware.NewScope())
chain := alice.New(middleware.NewScope(opts.ReverseProxy))
if opts.ForceHTTPS {
_, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress)
@ -366,49 +364,6 @@ func buildRoutesAllowlist(opts *options.Options) ([]allowedRoute, error) {
return routes, nil
}
// GetRedirectURI returns the redirectURL that the upstream OAuth Provider will
// redirect clients to once authenticated
func (p *OAuthProxy) GetRedirectURI(host string) string {
// default to the request Host if not set
if p.redirectURL.Host != "" {
return p.redirectURL.String()
}
u := *p.redirectURL
if u.Scheme == "" {
if p.CookieSecure {
u.Scheme = httpsScheme
} else {
u.Scheme = httpScheme
}
}
u.Host = host
return u.String()
}
func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (*sessionsapi.SessionState, error) {
if code == "" {
return nil, providers.ErrMissingCode
}
redirectURI := p.GetRedirectURI(host)
s, err := p.provider.Redeem(ctx, redirectURI, code)
if err != nil {
return nil, err
}
return s, nil
}
func (p *OAuthProxy) enrichSessionState(ctx context.Context, s *sessionsapi.SessionState) error {
var err error
if s.Email == "" {
s.Email, err = p.provider.GetEmailAddress(ctx, s)
if err != nil && !errors.Is(err, providers.ErrNotImplemented) {
return err
}
}
return p.provider.EnrichSession(ctx, s)
}
// MakeCSRFCookie creates a cookie for CSRF
func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
return p.makeCookie(req, p.CSRFCookieName, value, expiration, now)
@ -418,7 +373,7 @@ func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, ex
cookieDomain := cookies.GetCookieDomain(req, p.CookieDomains)
if cookieDomain != "" {
domain := util.GetRequestHost(req)
domain := requestutil.GetRequestHost(req)
if h, _, err := net.SplitHostPort(domain); err == nil {
domain = h
}
@ -466,6 +421,81 @@ func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *s
return p.sessionStore.Save(rw, req, s)
}
// IsValidRedirect checks whether the redirect URL is whitelisted
func (p *OAuthProxy) IsValidRedirect(redirect string) bool {
switch {
case redirect == "":
// The user didn't specify a redirect, should fallback to `/`
return false
case strings.HasPrefix(redirect, "/") && !strings.HasPrefix(redirect, "//") && !invalidRedirectRegex.MatchString(redirect):
return true
case strings.HasPrefix(redirect, "http://") || strings.HasPrefix(redirect, "https://"):
redirectURL, err := url.Parse(redirect)
if err != nil {
logger.Printf("Rejecting invalid redirect %q: scheme unsupported or missing", redirect)
return false
}
redirectHostname := redirectURL.Hostname()
for _, domain := range p.whitelistDomains {
domainHostname, domainPort := splitHostPort(strings.TrimLeft(domain, "."))
if domainHostname == "" {
continue
}
if (redirectHostname == domainHostname) || (strings.HasPrefix(domain, ".") && strings.HasSuffix(redirectHostname, domainHostname)) {
// the domain names match, now validate the ports
// if the whitelisted domain's port is '*', allow all ports
// if the whitelisted domain contains a specific port, only allow that port
// if the whitelisted domain doesn't contain a port at all, only allow empty redirect ports ie http and https
redirectPort := redirectURL.Port()
if (domainPort == "*") ||
(domainPort == redirectPort) ||
(domainPort == "" && redirectPort == "") {
return true
}
}
}
logger.Printf("Rejecting invalid redirect %q: domain / port not in whitelist", redirect)
return false
default:
logger.Printf("Rejecting invalid redirect %q: not an absolute or relative URL", redirect)
return false
}
}
func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
p.preAuthChain.Then(http.HandlerFunc(p.serveHTTP)).ServeHTTP(rw, req)
}
func (p *OAuthProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
if req.URL.Path != p.AuthOnlyPath && strings.HasPrefix(req.URL.Path, p.ProxyPrefix) {
prepareNoCache(rw)
}
switch path := req.URL.Path; {
case path == p.RobotsPath:
p.RobotsTxt(rw)
case p.IsAllowedRequest(req):
p.SkipAuthProxy(rw, req)
case path == p.SignInPath:
p.SignIn(rw, req)
case path == p.SignOutPath:
p.SignOut(rw, req)
case path == p.OAuthStartPath:
p.OAuthStart(rw, req)
case path == p.OAuthCallbackPath:
p.OAuthCallback(rw, req)
case path == p.AuthOnlyPath:
p.AuthOnly(rw, req)
case path == p.UserInfoPath:
p.UserInfo(rw, req)
default:
p.Proxy(rw, req)
}
}
// RobotsTxt disallows scraping pages from the OAuthProxy
func (p *OAuthProxy) RobotsTxt(rw http.ResponseWriter) {
_, err := fmt.Fprintf(rw, "User-agent: *\nDisallow: /")
@ -496,6 +526,42 @@ func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, m
}
}
// IsAllowedRequest is used to check if auth should be skipped for this request
func (p *OAuthProxy) IsAllowedRequest(req *http.Request) bool {
isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS"
return isPreflightRequestAllowed || p.isAllowedRoute(req) || p.isTrustedIP(req)
}
// IsAllowedRoute is used to check if the request method & path is allowed without auth
func (p *OAuthProxy) isAllowedRoute(req *http.Request) bool {
for _, route := range p.allowedRoutes {
if (route.method == "" || req.Method == route.method) && route.pathRegex.MatchString(req.URL.Path) {
return true
}
}
return false
}
// isTrustedIP is used to check if a request comes from a trusted client IP address.
func (p *OAuthProxy) isTrustedIP(req *http.Request) bool {
if p.trustedIPs == nil {
return false
}
remoteAddr, err := ip.GetClientIP(p.realClientIPParser, req)
if err != nil {
logger.Errorf("Error obtaining real IP for trusted IP list: %v", err)
// Possibly spoofed X-Real-IP header
return false
}
if remoteAddr == nil {
return false
}
return p.trustedIPs.Has(remoteAddr)
}
// SignInPage writes the sing in template to the response
func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code int) {
prepareNoCache(rw)
@ -507,7 +573,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code
}
rw.WriteHeader(code)
redirectURL, err := p.GetRedirect(req)
redirectURL, err := p.getAppRedirect(req)
if err != nil {
logger.Errorf("Error obtaining redirect: %v", err)
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
@ -566,195 +632,9 @@ func (p *OAuthProxy) ManualSignIn(req *http.Request) (string, bool) {
return "", false
}
// GetRedirect reads the query parameter to get the URL to redirect clients to
// once authenticated with the OAuthProxy
func (p *OAuthProxy) GetRedirect(req *http.Request) (redirect string, err error) {
err = req.ParseForm()
if err != nil {
return
}
redirect = req.Header.Get("X-Auth-Request-Redirect")
if req.Form.Get("rd") != "" {
redirect = req.Form.Get("rd")
}
if !p.IsValidRedirect(redirect) {
// Use RequestURI to preserve ?query
redirect = req.URL.RequestURI()
if strings.HasPrefix(redirect, p.ProxyPrefix) {
redirect = "/"
}
}
return
}
// splitHostPort separates host and port. If the port is not valid, it returns
// the entire input as host, and it doesn't check the validity of the host.
// Unlike net.SplitHostPort, but per RFC 3986, it requires ports to be numeric.
// *** taken from net/url, modified validOptionalPort() to accept ":*"
func splitHostPort(hostport string) (host, port string) {
host = hostport
colon := strings.LastIndexByte(host, ':')
if colon != -1 && validOptionalPort(host[colon:]) {
host, port = host[:colon], host[colon+1:]
}
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
host = host[1 : len(host)-1]
}
return
}
// validOptionalPort reports whether port is either an empty string
// or matches /^:\d*$/
// *** taken from net/url, modified to accept ":*"
func validOptionalPort(port string) bool {
if port == "" || port == ":*" {
return true
}
if port[0] != ':' {
return false
}
for _, b := range port[1:] {
if b < '0' || b > '9' {
return false
}
}
return true
}
// IsValidRedirect checks whether the redirect URL is whitelisted
func (p *OAuthProxy) IsValidRedirect(redirect string) bool {
switch {
case redirect == "":
// The user didn't specify a redirect, should fallback to `/`
return false
case strings.HasPrefix(redirect, "/") && !strings.HasPrefix(redirect, "//") && !invalidRedirectRegex.MatchString(redirect):
return true
case strings.HasPrefix(redirect, "http://") || strings.HasPrefix(redirect, "https://"):
redirectURL, err := url.Parse(redirect)
if err != nil {
logger.Printf("Rejecting invalid redirect %q: scheme unsupported or missing", redirect)
return false
}
redirectHostname := redirectURL.Hostname()
for _, domain := range p.whitelistDomains {
domainHostname, domainPort := splitHostPort(strings.TrimLeft(domain, "."))
if domainHostname == "" {
continue
}
if (redirectHostname == domainHostname) || (strings.HasPrefix(domain, ".") && strings.HasSuffix(redirectHostname, domainHostname)) {
// the domain names match, now validate the ports
// if the whitelisted domain's port is '*', allow all ports
// if the whitelisted domain contains a specific port, only allow that port
// if the whitelisted domain doesn't contain a port at all, only allow empty redirect ports ie http and https
redirectPort := redirectURL.Port()
if (domainPort == "*") ||
(domainPort == redirectPort) ||
(domainPort == "" && redirectPort == "") {
return true
}
}
}
logger.Printf("Rejecting invalid redirect %q: domain / port not in whitelist", redirect)
return false
default:
logger.Printf("Rejecting invalid redirect %q: not an absolute or relative URL", redirect)
return false
}
}
// IsAllowedRequest is used to check if auth should be skipped for this request
func (p *OAuthProxy) IsAllowedRequest(req *http.Request) bool {
isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS"
return isPreflightRequestAllowed || p.isAllowedRoute(req) || p.IsTrustedIP(req)
}
// IsAllowedRoute is used to check if the request method & path is allowed without auth
func (p *OAuthProxy) isAllowedRoute(req *http.Request) bool {
for _, route := range p.allowedRoutes {
if (route.method == "" || req.Method == route.method) && route.pathRegex.MatchString(req.URL.Path) {
return true
}
}
return false
}
// See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en
var noCacheHeaders = map[string]string{
"Expires": time.Unix(0, 0).Format(time.RFC1123),
"Cache-Control": "no-cache, no-store, must-revalidate, max-age=0",
"X-Accel-Expires": "0", // https://www.nginx.com/resources/wiki/start/topics/examples/x-accel/
}
// prepareNoCache prepares headers for preventing browser caching.
func prepareNoCache(w http.ResponseWriter) {
// Set NoCache headers
for k, v := range noCacheHeaders {
w.Header().Set(k, v)
}
}
// IsTrustedIP is used to check if a request comes from a trusted client IP address.
func (p *OAuthProxy) IsTrustedIP(req *http.Request) bool {
if p.trustedIPs == nil {
return false
}
remoteAddr, err := ip.GetClientIP(p.realClientIPParser, req)
if err != nil {
logger.Errorf("Error obtaining real IP for trusted IP list: %v", err)
// Possibly spoofed X-Real-IP header
return false
}
if remoteAddr == nil {
return false
}
return p.trustedIPs.Has(remoteAddr)
}
func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
p.preAuthChain.Then(http.HandlerFunc(p.serveHTTP)).ServeHTTP(rw, req)
}
func (p *OAuthProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
if req.URL.Path != p.AuthOnlyPath && strings.HasPrefix(req.URL.Path, p.ProxyPrefix) {
prepareNoCache(rw)
}
switch path := req.URL.Path; {
case path == p.RobotsPath:
p.RobotsTxt(rw)
case p.IsAllowedRequest(req):
p.SkipAuthProxy(rw, req)
case path == p.SignInPath:
p.SignIn(rw, req)
case path == p.SignOutPath:
p.SignOut(rw, req)
case path == p.OAuthStartPath:
p.OAuthStart(rw, req)
case path == p.OAuthCallbackPath:
p.OAuthCallback(rw, req)
case path == p.AuthOnlyPath:
p.AuthenticateOnly(rw, req)
case path == p.UserInfoPath:
p.UserInfo(rw, req)
default:
p.Proxy(rw, req)
}
}
// SignIn serves a page prompting users to sign in
func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) {
redirect, err := p.GetRedirect(req)
redirect, err := p.getAppRedirect(req)
if err != nil {
logger.Errorf("Error obtaining redirect: %v", err)
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
@ -812,7 +692,7 @@ func (p *OAuthProxy) UserInfo(rw http.ResponseWriter, req *http.Request) {
// SignOut sends a response to clear the authentication cookie
func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) {
redirect, err := p.GetRedirect(req)
redirect, err := p.getAppRedirect(req)
if err != nil {
logger.Errorf("Error obtaining redirect: %v", err)
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
@ -837,13 +717,13 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
return
}
p.SetCSRFCookie(rw, req, nonce)
redirect, err := p.GetRedirect(req)
redirect, err := p.getAppRedirect(req)
if err != nil {
logger.Errorf("Error obtaining redirect: %v", err)
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
return
}
redirectURI := p.GetRedirectURI(util.GetRequestHost(req))
redirectURI := p.getOAuthRedirectURI(req)
http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), http.StatusFound)
}
@ -866,7 +746,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
return
}
session, err := p.redeemCode(req.Context(), util.GetRequestHost(req), req.Form.Get("code"))
session, err := p.redeemCode(req)
if err != nil {
logger.Errorf("Error redeeming code during OAuth2 callback: %v", err)
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error")
@ -925,16 +805,50 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
}
}
// AuthenticateOnly checks whether the user is currently logged in
func (p *OAuthProxy) AuthenticateOnly(rw http.ResponseWriter, req *http.Request) {
func (p *OAuthProxy) redeemCode(req *http.Request) (*sessionsapi.SessionState, error) {
code := req.Form.Get("code")
if code == "" {
return nil, providers.ErrMissingCode
}
redirectURI := p.getOAuthRedirectURI(req)
s, err := p.provider.Redeem(req.Context(), redirectURI, code)
if err != nil {
return nil, err
}
return s, nil
}
func (p *OAuthProxy) enrichSessionState(ctx context.Context, s *sessionsapi.SessionState) error {
var err error
if s.Email == "" {
s.Email, err = p.provider.GetEmailAddress(ctx, s)
if err != nil && !errors.Is(err, providers.ErrNotImplemented) {
return err
}
}
return p.provider.EnrichSession(ctx, s)
}
// AuthOnly checks whether the user is currently logged in (both authentication
// and optional authorization).
func (p *OAuthProxy) AuthOnly(rw http.ResponseWriter, req *http.Request) {
session, err := p.getAuthenticatedSession(rw, req)
if err != nil {
http.Error(rw, "unauthorized request", http.StatusUnauthorized)
http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
// Unauthorized cases need to return 403 to prevent infinite redirects with
// subrequest architectures
if !authOnlyAuthorize(req, session) {
http.Error(rw, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}
// we are authenticated
p.addHeadersForProxying(rw, req, session)
p.addHeadersForProxying(rw, session)
p.headersChain.Then(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusAccepted)
})).ServeHTTP(rw, req)
@ -952,13 +866,13 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
switch err {
case nil:
// we are authenticated
p.addHeadersForProxying(rw, req, session)
p.addHeadersForProxying(rw, session)
p.headersChain.Then(p.serveMux).ServeHTTP(rw, req)
case ErrNeedsLogin:
// we need to send the user to a login screen
if isAjax(req) {
// no point redirecting an AJAX request
p.ErrorJSON(rw, http.StatusUnauthorized)
p.errorJSON(rw, http.StatusUnauthorized)
return
}
@ -977,7 +891,195 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
p.ErrorPage(rw, http.StatusInternalServerError,
"Internal Error", "Internal Error")
}
}
// See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en
var noCacheHeaders = map[string]string{
"Expires": time.Unix(0, 0).Format(time.RFC1123),
"Cache-Control": "no-cache, no-store, must-revalidate, max-age=0",
"X-Accel-Expires": "0", // https://www.nginx.com/resources/wiki/start/topics/examples/x-accel/
}
// prepareNoCache prepares headers for preventing browser caching.
func prepareNoCache(w http.ResponseWriter) {
// Set NoCache headers
for k, v := range noCacheHeaders {
w.Header().Set(k, v)
}
}
// getOAuthRedirectURI returns the redirectURL that the upstream OAuth Provider will
// redirect clients to once authenticated.
// This is usually the OAuthProxy callback URL.
func (p *OAuthProxy) getOAuthRedirectURI(req *http.Request) string {
// if `p.redirectURL` already has a host, return it
if p.redirectURL.Host != "" {
return p.redirectURL.String()
}
// Otherwise figure out the scheme + host from the request
rd := *p.redirectURL
rd.Host = requestutil.GetRequestHost(req)
rd.Scheme = requestutil.GetRequestProto(req)
// If CookieSecure is true, return `https` no matter what
// Not all reverse proxies set X-Forwarded-Proto
if p.CookieSecure {
rd.Scheme = schemeHTTPS
}
return rd.String()
}
// getAppRedirect determines the full URL or URI path to redirect clients to
// once authenticated with the OAuthProxy
// Strategy priority (first legal result is used):
// - `rd` querysting parameter
// - `X-Auth-Request-Redirect` header
// - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled)
// - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*)
// - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled)
// - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*)
// - `/`
func (p *OAuthProxy) getAppRedirect(req *http.Request) (string, error) {
err := req.ParseForm()
if err != nil {
return "", err
}
// These redirect getter functions are strategies ordered by priority
// for figuring out the redirect URL.
type redirectGetter func(req *http.Request) string
for _, rdGetter := range []redirectGetter{
p.getRdQuerystringRedirect,
p.getXAuthRequestRedirect,
p.getXForwardedHeadersRedirect,
p.getURIRedirect,
} {
redirect := rdGetter(req)
// Call `p.IsValidRedirect` again here a final time to be safe
if redirect != "" && p.IsValidRedirect(redirect) {
return redirect, nil
}
}
return "/", nil
}
func isForwardedRequest(req *http.Request) bool {
return requestutil.IsProxied(req) &&
req.Host != requestutil.GetRequestHost(req)
}
func (p *OAuthProxy) hasProxyPrefix(path string) bool {
return strings.HasPrefix(path, fmt.Sprintf("%s/", p.ProxyPrefix))
}
func (p *OAuthProxy) validateRedirect(redirect string, errorFormat string) string {
if p.IsValidRedirect(redirect) {
return redirect
}
if redirect != "" {
logger.Errorf(errorFormat, redirect)
}
return ""
}
// getRdQuerystringRedirect handles this getAppRedirect strategy:
// - `rd` querysting parameter
func (p *OAuthProxy) getRdQuerystringRedirect(req *http.Request) string {
return p.validateRedirect(
req.Form.Get("rd"),
"Invalid redirect provided in rd querystring parameter: %s",
)
}
// getXAuthRequestRedirect handles this getAppRedirect strategy:
// - `X-Auth-Request-Redirect` Header
func (p *OAuthProxy) getXAuthRequestRedirect(req *http.Request) string {
return p.validateRedirect(
req.Header.Get("X-Auth-Request-Redirect"),
"Invalid redirect provided in X-Auth-Request-Redirect header: %s",
)
}
// getXForwardedHeadersRedirect handles these getAppRedirect strategies:
// - `X-Forwarded-(Proto|Host|Uri)` headers (when ReverseProxy mode is enabled)
// - `X-Forwarded-(Proto|Host)` if `Uri` has the ProxyPath (i.e. /oauth2/*)
func (p *OAuthProxy) getXForwardedHeadersRedirect(req *http.Request) string {
if !isForwardedRequest(req) {
return ""
}
uri := requestutil.GetRequestURI(req)
if p.hasProxyPrefix(uri) {
uri = "/"
}
redirect := fmt.Sprintf(
"%s://%s%s",
requestutil.GetRequestProto(req),
requestutil.GetRequestHost(req),
uri,
)
return p.validateRedirect(redirect,
"Invalid redirect generated from X-Forwarded-* headers: %s")
}
// getURIRedirect handles these getAppRedirect strategies:
// - `X-Forwarded-Uri` direct URI path (when ReverseProxy mode is enabled)
// - `req.URL.RequestURI` if not under the ProxyPath (i.e. /oauth2/*)
// - `/`
func (p *OAuthProxy) getURIRedirect(req *http.Request) string {
redirect := p.validateRedirect(
requestutil.GetRequestURI(req),
"Invalid redirect generated from X-Forwarded-Uri header: %s",
)
if redirect == "" {
redirect = req.URL.RequestURI()
}
if p.hasProxyPrefix(redirect) {
return "/"
}
return redirect
}
// splitHostPort separates host and port. If the port is not valid, it returns
// the entire input as host, and it doesn't check the validity of the host.
// Unlike net.SplitHostPort, but per RFC 3986, it requires ports to be numeric.
// *** taken from net/url, modified validOptionalPort() to accept ":*"
func splitHostPort(hostport string) (host, port string) {
host = hostport
colon := strings.LastIndexByte(host, ':')
if colon != -1 && validOptionalPort(host[colon:]) {
host, port = host[:colon], host[colon+1:]
}
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
host = host[1 : len(host)-1]
}
return
}
// validOptionalPort reports whether port is either an empty string
// or matches /^:\d*$/
// *** taken from net/url, modified to accept ":*"
func validOptionalPort(port string) bool {
if port == "" || port == ":*" {
return true
}
if port[0] != ':' {
return false
}
for _, b := range port[1:] {
if b < '0' || b > '9' {
return false
}
}
return true
}
// getAuthenticatedSession checks whether a user is authenticated and returns a session object and nil error if so
@ -989,7 +1091,7 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R
var session *sessionsapi.SessionState
getSession := p.sessionChain.Then(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
session = middleware.GetRequestScope(req).Session
session = middlewareapi.GetRequestScope(req).Session
}))
getSession.ServeHTTP(rw, req)
@ -1016,8 +1118,55 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R
return session, nil
}
// authOnlyAuthorize handles special authorization logic that is only done
// on the AuthOnly endpoint for use with Nginx subrequest architectures.
//
// TODO (@NickMeves): This method is a placeholder to be extended but currently
// fails the linter. Remove the nolint when functionality expands.
//
//nolint:S1008
func authOnlyAuthorize(req *http.Request, s *sessionsapi.SessionState) bool {
// Allow secondary group restrictions based on the `allowed_groups`
// querystring parameter
if !checkAllowedGroups(req, s) {
return false
}
return true
}
func checkAllowedGroups(req *http.Request, s *sessionsapi.SessionState) bool {
allowedGroups := extractAllowedGroups(req)
if len(allowedGroups) == 0 {
return true
}
for _, group := range s.Groups {
if _, ok := allowedGroups[group]; ok {
return true
}
}
return false
}
func extractAllowedGroups(req *http.Request) map[string]struct{} {
groups := map[string]struct{}{}
query := req.URL.Query()
for _, allowedGroups := range query["allowed_groups"] {
for _, group := range strings.Split(allowedGroups, ",") {
if group != "" {
groups[group] = struct{}{}
}
}
}
return groups
}
// addHeadersForProxying adds the appropriate headers the request / response for proxying
func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) {
func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, session *sessionsapi.SessionState) {
if session.Email == "" {
rw.Header().Set("GAP-Auth", session.User)
} else {
@ -1029,16 +1178,24 @@ func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, req *http.Req
func isAjax(req *http.Request) bool {
acceptValues := req.Header.Values("Accept")
const ajaxReq = applicationJSON
for _, v := range acceptValues {
if v == ajaxReq {
// Iterate over multiple Accept headers, i.e.
// Accept: application/json
// Accept: text/plain
for _, mimeTypes := range acceptValues {
// Iterate over multiple mimetypes in a single header, i.e.
// Accept: application/json, text/plain, */*
for _, mimeType := range strings.Split(mimeTypes, ",") {
mimeType = strings.TrimSpace(mimeType)
if mimeType == ajaxReq {
return true
}
}
}
return false
}
// ErrorJSON returns the error code with an application/json mime type
func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) {
// errorJSON returns the error code with an application/json mime type
func (p *OAuthProxy) errorJSON(rw http.ResponseWriter, code int) {
rw.Header().Set("Content-Type", applicationJSON)
rw.WriteHeader(code)
}

View File

@ -19,6 +19,7 @@ import (
"github.com/coreos/go-oidc"
"github.com/mbland/hmacauth"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
@ -414,8 +415,9 @@ func Test_redeemCode(t *testing.T) {
t.Fatal(err)
}
_, err = proxy.redeemCode(context.Background(), "www.example.com", "")
assert.Error(t, err)
req := httptest.NewRequest(http.MethodGet, "/", nil)
_, err = proxy.redeemCode(req)
assert.Equal(t, providers.ErrMissingCode, err)
}
func Test_enrichSession(t *testing.T) {
@ -1197,18 +1199,20 @@ func TestUserInfoEndpointUnauthorizedOnNoCookieSetError(t *testing.T) {
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
}
func NewAuthOnlyEndpointTest(modifiers ...OptionsModifier) (*ProcessCookieTest, error) {
func NewAuthOnlyEndpointTest(querystring string, modifiers ...OptionsModifier) (*ProcessCookieTest, error) {
pcTest, err := NewProcessCookieTestWithOptionsModifiers(modifiers...)
if err != nil {
return nil, err
}
pcTest.req, _ = http.NewRequest("GET",
pcTest.opts.ProxyPrefix+"/auth", nil)
pcTest.req, _ = http.NewRequest(
"GET",
fmt.Sprintf("%s/auth%s", pcTest.opts.ProxyPrefix, querystring),
nil)
return pcTest, nil
}
func TestAuthOnlyEndpointAccepted(t *testing.T) {
test, err := NewAuthOnlyEndpointTest()
test, err := NewAuthOnlyEndpointTest("")
if err != nil {
t.Fatal(err)
}
@ -1226,7 +1230,7 @@ func TestAuthOnlyEndpointAccepted(t *testing.T) {
}
func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) {
test, err := NewAuthOnlyEndpointTest()
test, err := NewAuthOnlyEndpointTest("")
if err != nil {
t.Fatal(err)
}
@ -1234,11 +1238,11 @@ func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) {
test.proxy.ServeHTTP(test.rw, test.req)
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
assert.Equal(t, "unauthorized request\n", string(bodyBytes))
assert.Equal(t, "Unauthorized\n", string(bodyBytes))
}
func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
test, err := NewAuthOnlyEndpointTest(func(opts *options.Options) {
test, err := NewAuthOnlyEndpointTest("", func(opts *options.Options) {
opts.Cookie.Expire = time.Duration(24) * time.Hour
})
if err != nil {
@ -1254,11 +1258,11 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
test.proxy.ServeHTTP(test.rw, test.req)
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
assert.Equal(t, "unauthorized request\n", string(bodyBytes))
assert.Equal(t, "Unauthorized\n", string(bodyBytes))
}
func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
test, err := NewAuthOnlyEndpointTest()
test, err := NewAuthOnlyEndpointTest("")
if err != nil {
t.Fatal(err)
}
@ -1273,7 +1277,7 @@ func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
test.proxy.ServeHTTP(test.rw, test.req)
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
assert.Equal(t, "unauthorized request\n", string(bodyBytes))
assert.Equal(t, "Unauthorized\n", string(bodyBytes))
}
func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
@ -1746,8 +1750,9 @@ func TestRequestSignature(t *testing.T) {
}
}
func TestGetRedirect(t *testing.T) {
func Test_getAppRedirect(t *testing.T) {
opts := baseTestOptions()
opts.WhitelistDomains = append(opts.WhitelistDomains, ".example.com", ".example.com:8443")
err := validation.Validate(opts)
assert.NoError(t, err)
require.NotEmpty(t, opts.ProxyPrefix)
@ -1759,28 +1764,144 @@ func TestGetRedirect(t *testing.T) {
tests := []struct {
name string
url string
headers map[string]string
reverseProxy bool
expectedRedirect string
}{
{
name: "request outside of ProxyPrefix redirects to original URL",
url: "/foo/bar",
headers: nil,
reverseProxy: false,
expectedRedirect: "/foo/bar",
},
{
name: "request with query preserves query",
url: "/foo?bar",
headers: nil,
reverseProxy: false,
expectedRedirect: "/foo?bar",
},
{
name: "request under ProxyPrefix redirects to root",
url: proxy.ProxyPrefix + "/foo/bar",
headers: nil,
reverseProxy: false,
expectedRedirect: "/",
},
{
name: "proxied request outside of ProxyPrefix redirects to proxied URL",
url: "https://oauth.example.com/foo/bar",
headers: map[string]string{
"X-Forwarded-Proto": "https",
"X-Forwarded-Host": "a-service.example.com",
"X-Forwarded-Uri": "/foo/bar",
},
reverseProxy: true,
expectedRedirect: "https://a-service.example.com/foo/bar",
},
{
name: "non-proxied request with spoofed proxy headers wouldn't redirect",
url: "https://oauth.example.com/foo?bar",
headers: map[string]string{
"X-Forwarded-Proto": "https",
"X-Forwarded-Host": "a-service.example.com",
"X-Forwarded-Uri": "/foo/bar",
},
reverseProxy: false,
expectedRedirect: "/foo?bar",
},
{
name: "proxied request under ProxyPrefix redirects to root",
url: "https://oauth.example.com" + proxy.ProxyPrefix + "/foo/bar",
headers: map[string]string{
"X-Forwarded-Proto": "https",
"X-Forwarded-Host": "a-service.example.com",
"X-Forwarded-Uri": proxy.ProxyPrefix + "/foo/bar",
},
reverseProxy: true,
expectedRedirect: "https://a-service.example.com/",
},
{
name: "proxied request with port under ProxyPrefix redirects to root",
url: "https://oauth.example.com" + proxy.ProxyPrefix + "/foo/bar",
headers: map[string]string{
"X-Forwarded-Proto": "https",
"X-Forwarded-Host": "a-service.example.com:8443",
"X-Forwarded-Uri": proxy.ProxyPrefix + "/foo/bar",
},
reverseProxy: true,
expectedRedirect: "https://a-service.example.com:8443/",
},
{
name: "proxied request with missing uri header would still redirect to desired redirect",
url: "https://oauth.example.com/foo?bar",
headers: map[string]string{
"X-Forwarded-Proto": "https",
"X-Forwarded-Host": "a-service.example.com",
},
reverseProxy: true,
expectedRedirect: "https://a-service.example.com/foo?bar",
},
{
name: "request with headers proxy not being set (and reverse proxy enabled) would still redirect to desired redirect",
url: "https://oauth.example.com/foo?bar",
headers: nil,
reverseProxy: true,
expectedRedirect: "/foo?bar",
},
{
name: "proxied request with X-Auth-Request-Redirect being set outside of ProxyPrefix redirects to proxied URL",
url: "https://oauth.example.com/foo/bar",
headers: map[string]string{
"X-Auth-Request-Redirect": "https://a-service.example.com/foo/bar",
},
reverseProxy: true,
expectedRedirect: "https://a-service.example.com/foo/bar",
},
{
name: "proxied request with rd query string redirects to proxied URL",
url: "https://oauth.example.com/foo/bar?rd=https%3A%2F%2Fa%2Dservice%2Eexample%2Ecom%2Ffoo%2Fbar",
headers: nil,
reverseProxy: false,
expectedRedirect: "https://a-service.example.com/foo/bar",
},
{
name: "proxied request with rd query string and all headers set (and reverse proxy not enabled) redirects to proxied URL on rd query string",
url: "https://oauth.example.com/foo/bar?rd=https%3A%2F%2Fa%2Dservice%2Eexample%2Ecom%2Ffoo%2Fjazz",
headers: map[string]string{
"X-Auth-Request-Redirect": "https://a-service.example.com/foo/baz",
"X-Forwarded-Proto": "http",
"X-Forwarded-Host": "another-service.example.com",
"X-Forwarded-Uri": "/seasons/greetings",
},
reverseProxy: false,
expectedRedirect: "https://a-service.example.com/foo/jazz",
},
{
name: "proxied request with rd query string and some headers set redirects to proxied URL on rd query string",
url: "https://oauth.example.com/foo/bar?rd=https%3A%2F%2Fa%2Dservice%2Eexample%2Ecom%2Ffoo%2Fbaz",
headers: map[string]string{
"X-Forwarded-Proto": "https",
"X-Forwarded-Host": "another-service.example.com",
"X-Forwarded-Uri": "/seasons/greetings",
},
reverseProxy: true,
expectedRedirect: "https://a-service.example.com/foo/baz",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, _ := http.NewRequest("GET", tt.url, nil)
redirect, err := proxy.GetRedirect(req)
for header, value := range tt.headers {
if value != "" {
req.Header.Add(header, value)
}
}
req = middleware.AddRequestScope(req, &middleware.RequestScope{
ReverseProxy: tt.reverseProxy,
})
redirect, err := proxy.getAppRedirect(req)
assert.NoError(t, err)
assert.Equal(t, tt.expectedRedirect, redirect)
@ -1848,6 +1969,13 @@ func TestAjaxUnauthorizedRequest2(t *testing.T) {
testAjaxUnauthorizedRequest(t, header)
}
func TestAjaxUnauthorizedRequestAccept1(t *testing.T) {
header := make(http.Header)
header.Add("Accept", "application/json, text/plain, */*")
testAjaxUnauthorizedRequest(t, header)
}
func TestAjaxForbiddendRequest(t *testing.T) {
test, err := newAjaxRequestTest()
if err != nil {
@ -1960,7 +2088,7 @@ func TestGetJwtSession(t *testing.T) {
verifier := oidc.NewVerifier("https://issuer.example.com", keyset,
&oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true})
test, err := NewAuthOnlyEndpointTest(func(opts *options.Options) {
test, err := NewAuthOnlyEndpointTest("", func(opts *options.Options) {
opts.InjectRequestHeaders = []options.Header{
{
Name: "Authorization",
@ -2028,7 +2156,6 @@ func TestGetJwtSession(t *testing.T) {
},
},
}
opts.SkipJwtBearerTokens = true
opts.SetJWTBearerVerifiers(append(opts.GetJWTBearerVerifiers(), verifier))
})
@ -2692,32 +2819,106 @@ func TestProxyAllowedGroups(t *testing.T) {
}
func TestAuthOnlyAllowedGroups(t *testing.T) {
tests := []struct {
testCases := []struct {
name string
allowedGroups []string
groups []string
expectUnauthorized bool
querystring string
expectedStatusCode int
}{
{"NoAllowedGroups", []string{}, []string{}, false},
{"NoAllowedGroupsUserHasGroups", []string{}, []string{"a", "b"}, false},
{"UserInAllowedGroup", []string{"a"}, []string{"a", "b"}, false},
{"UserNotInAllowedGroup", []string{"a"}, []string{"c"}, true},
{
name: "NoAllowedGroups",
allowedGroups: []string{},
groups: []string{},
querystring: "",
expectedStatusCode: http.StatusAccepted,
},
{
name: "NoAllowedGroupsUserHasGroups",
allowedGroups: []string{},
groups: []string{"a", "b"},
querystring: "",
expectedStatusCode: http.StatusAccepted,
},
{
name: "UserInAllowedGroup",
allowedGroups: []string{"a"},
groups: []string{"a", "b"},
querystring: "",
expectedStatusCode: http.StatusAccepted,
},
{
name: "UserNotInAllowedGroup",
allowedGroups: []string{"a"},
groups: []string{"c"},
querystring: "",
expectedStatusCode: http.StatusUnauthorized,
},
{
name: "UserInQuerystringGroup",
allowedGroups: []string{"a", "b"},
groups: []string{"a", "c"},
querystring: "?allowed_groups=a",
expectedStatusCode: http.StatusAccepted,
},
{
name: "UserInMultiParamQuerystringGroup",
allowedGroups: []string{"a", "b"},
groups: []string{"b"},
querystring: "?allowed_groups=a&allowed_groups=b,d",
expectedStatusCode: http.StatusAccepted,
},
{
name: "UserInOnlyQuerystringGroup",
allowedGroups: []string{},
groups: []string{"a", "c"},
querystring: "?allowed_groups=a,b",
expectedStatusCode: http.StatusAccepted,
},
{
name: "UserInDelimitedQuerystringGroup",
allowedGroups: []string{"a", "b", "c"},
groups: []string{"c"},
querystring: "?allowed_groups=a,c",
expectedStatusCode: http.StatusAccepted,
},
{
name: "UserNotInQuerystringGroup",
allowedGroups: []string{},
groups: []string{"c"},
querystring: "?allowed_groups=a,b",
expectedStatusCode: http.StatusForbidden,
},
{
name: "UserInConfigGroupNotInQuerystringGroup",
allowedGroups: []string{"a", "b", "c"},
groups: []string{"c"},
querystring: "?allowed_groups=a,b",
expectedStatusCode: http.StatusForbidden,
},
{
name: "UserInQuerystringGroupNotInConfigGroup",
allowedGroups: []string{"a", "b"},
groups: []string{"c"},
querystring: "?allowed_groups=b,c",
expectedStatusCode: http.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
emailAddress := "test"
created := time.Now()
session := &sessions.SessionState{
Groups: tt.groups,
Groups: tc.groups,
Email: emailAddress,
AccessToken: "oauth_token",
CreatedAt: &created,
}
test, err := NewAuthOnlyEndpointTest(func(opts *options.Options) {
opts.AllowedGroups = tt.allowedGroups
test, err := NewAuthOnlyEndpointTest(tc.querystring, func(opts *options.Options) {
opts.AllowedGroups = tc.allowedGroups
})
if err != nil {
t.Fatal(err)
@ -2728,11 +2929,7 @@ func TestAuthOnlyAllowedGroups(t *testing.T) {
test.proxy.ServeHTTP(test.rw, test.req)
if tt.expectUnauthorized {
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
} else {
assert.Equal(t, http.StatusAccepted, test.rw.Code)
}
assert.Equal(t, tc.expectedStatusCode, test.rw.Code)
})
}
}

View File

@ -0,0 +1,19 @@
package middleware_test
import (
"testing"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
// TestMiddlewareSuite and related tests are in a *_test package
// to prevent circular imports with the `logger` package which uses
// this functionality
func TestMiddlewareSuite(t *testing.T) {
logger.SetOutput(GinkgoWriter)
RegisterFailHandler(Fail)
RunSpecs(t, "Middleware API")
}

View File

@ -1,13 +1,26 @@
package middleware
import (
"context"
"net/http"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
)
type scopeKey string
// RequestScopeKey uses a typed string to reduce likelihood of clashing
// with other context keys
const RequestScopeKey scopeKey = "request-scope"
// RequestScope contains information regarding the request that is being made.
// The RequestScope is used to pass information between different middlewares
// within the chain.
type RequestScope struct {
// ReverseProxy tracks whether OAuth2-Proxy is operating in reverse proxy
// mode and if request `X-Forwarded-*` headers should be trusted
ReverseProxy bool
// Session details the authenticated users information (if it exists).
Session *sessions.SessionState
@ -22,3 +35,19 @@ type RequestScope struct {
// it was loaded or not.
SessionRevalidated bool
}
// GetRequestScope returns the current request scope from the given request
func GetRequestScope(req *http.Request) *RequestScope {
scope := req.Context().Value(RequestScopeKey)
if scope == nil {
return nil
}
return scope.(*RequestScope)
}
// AddRequestScope adds a RequestScope to a request
func AddRequestScope(req *http.Request, scope *RequestScope) *http.Request {
ctx := context.WithValue(req.Context(), RequestScopeKey, scope)
return req.WithContext(ctx)
}

View File

@ -0,0 +1,56 @@
package middleware_test
import (
"net/http"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Scope Suite", func() {
Context("GetRequestScope", func() {
var request *http.Request
BeforeEach(func() {
var err error
request, err = http.NewRequest("", "http://127.0.0.1/", nil)
Expect(err).ToNot(HaveOccurred())
})
Context("with a scope", func() {
var scope *middleware.RequestScope
BeforeEach(func() {
scope = &middleware.RequestScope{}
request = middleware.AddRequestScope(request, scope)
})
It("returns the scope", func() {
s := middleware.GetRequestScope(request)
Expect(s).ToNot(BeNil())
Expect(s).To(Equal(scope))
})
Context("if the scope is then modified", func() {
BeforeEach(func() {
Expect(scope.SaveSession).To(BeFalse())
scope.SaveSession = true
})
It("returns the updated session", func() {
s := middleware.GetRequestScope(request)
Expect(s).ToNot(BeNil())
Expect(s).To(Equal(scope))
Expect(s.SaveSession).To(BeTrue())
})
})
})
Context("without a scope", func() {
It("returns nil", func() {
Expect(middleware.GetRequestScope(request)).To(BeNil())
})
})
})
})

View File

@ -36,7 +36,7 @@ type Options struct {
TLSKeyFile string `flag:"tls-key-file" cfg:"tls_key_file"`
AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"`
KeycloakGroup string `flag:"keycloak-group" cfg:"keycloak_group"`
KeycloakGroups []string `flag:"keycloak-group" cfg:"keycloak_groups"`
AzureTenant string `flag:"azure-tenant" cfg:"azure_tenant"`
BitbucketTeam string `flag:"bitbucket-team" cfg:"bitbucket_team"`
BitbucketRepository string `flag:"bitbucket-repository" cfg:"bitbucket_repository"`
@ -87,6 +87,7 @@ type Options struct {
InsecureOIDCSkipIssuerVerification bool `flag:"insecure-oidc-skip-issuer-verification" cfg:"insecure_oidc_skip_issuer_verification"`
SkipOIDCDiscovery bool `flag:"skip-oidc-discovery" cfg:"skip_oidc_discovery"`
OIDCJwksURL string `flag:"oidc-jwks-url" cfg:"oidc_jwks_url"`
OIDCEmailClaim string `flag:"oidc-email-claim" cfg:"oidc_email_claim"`
OIDCGroupsClaim string `flag:"oidc-groups-claim" cfg:"oidc_groups_claim"`
LoginURL string `flag:"login-url" cfg:"login_url"`
RedeemURL string `flag:"redeem-url" cfg:"redeem_url"`
@ -148,11 +149,12 @@ func NewOptions() *Options {
SkipAuthPreflight: false,
Prompt: "", // Change to "login" when ApprovalPrompt officially deprecated
ApprovalPrompt: "force",
UserIDClaim: "email",
InsecureOIDCAllowUnverifiedEmail: false,
SkipOIDCDiscovery: false,
Logging: loggingDefaults(),
OIDCGroupsClaim: "groups",
UserIDClaim: providers.OIDCEmailClaim, // Deprecated: Use OIDCEmailClaim
OIDCEmailClaim: providers.OIDCEmailClaim,
OIDCGroupsClaim: providers.OIDCGroupsClaim,
}
}
@ -179,7 +181,7 @@ func NewFlagSet() *pflag.FlagSet {
flagSet.StringSlice("email-domain", []string{}, "authenticate emails with the specified domain (may be given multiple times). Use * to authenticate any email")
flagSet.StringSlice("whitelist-domain", []string{}, "allowed domains for redirection after authentication. Prefix domain with a . to allow subdomains (eg .example.com)")
flagSet.String("keycloak-group", "", "restrict login to members of this group.")
flagSet.StringSlice("keycloak-group", []string{}, "restrict logins to members of these groups (may be given multiple times)")
flagSet.String("azure-tenant", "common", "go to a tenant-specific or common (tenant-independent) endpoint.")
flagSet.String("bitbucket-team", "", "restrict logins to members of this team")
flagSet.String("bitbucket-repository", "", "restrict logins to user with access to this repository")
@ -226,7 +228,8 @@ func NewFlagSet() *pflag.FlagSet {
flagSet.Bool("insecure-oidc-skip-issuer-verification", false, "Do not verify if issuer matches OIDC discovery URL")
flagSet.Bool("skip-oidc-discovery", false, "Skip OIDC discovery and use manually supplied Endpoints")
flagSet.String("oidc-jwks-url", "", "OpenID Connect JWKS URL (ie: https://www.googleapis.com/oauth2/v3/certs)")
flagSet.String("oidc-groups-claim", "groups", "which claim contains the user groups")
flagSet.String("oidc-groups-claim", providers.OIDCGroupsClaim, "which OIDC claim contains the user groups")
flagSet.String("oidc-email-claim", providers.OIDCEmailClaim, "which OIDC claim contains the user's email")
flagSet.String("login-url", "", "Authentication endpoint")
flagSet.String("redeem-url", "", "Token redemption endpoint")
flagSet.String("profile-url", "", "Profile access endpoint")
@ -243,7 +246,7 @@ func NewFlagSet() *pflag.FlagSet {
flagSet.String("pubjwk-url", "", "JWK pubkey access endpoint: required by login.gov")
flagSet.Bool("gcp-healthchecks", false, "Enable GCP/GKE healthcheck endpoints")
flagSet.String("user-id-claim", "email", "which claim contains the user ID")
flagSet.String("user-id-claim", providers.OIDCEmailClaim, "(DEPRECATED for `oidc-email-claim`) which claim contains the user ID")
flagSet.StringSlice("allowed-group", []string{}, "restrict logins to members of this group (may be given multiple times)")
flagSet.AddFlagSet(cookieFlagSet())

View File

@ -9,14 +9,14 @@ import (
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util"
requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util"
)
// MakeCookie constructs a cookie from the given parameters,
// discovering the domain from the request if not specified.
func MakeCookie(req *http.Request, name string, value string, path string, domain string, httpOnly bool, secure bool, expiration time.Duration, now time.Time, sameSite http.SameSite) *http.Cookie {
if domain != "" {
host := util.GetRequestHost(req)
host := requestutil.GetRequestHost(req)
if h, _, err := net.SplitHostPort(host); err == nil {
host = h
}
@ -48,7 +48,7 @@ func MakeCookieFromOptions(req *http.Request, name string, value string, cookieO
// If nothing matches, create the cookie with the shortest domain
defaultDomain := ""
if len(cookieOpts.Domains) > 0 {
logger.Errorf("Warning: request host %q did not match any of the specific cookie domains of %q", util.GetRequestHost(req), strings.Join(cookieOpts.Domains, ","))
logger.Errorf("Warning: request host %q did not match any of the specific cookie domains of %q", requestutil.GetRequestHost(req), strings.Join(cookieOpts.Domains, ","))
defaultDomain = cookieOpts.Domains[len(cookieOpts.Domains)-1]
}
return MakeCookie(req, name, value, cookieOpts.Path, defaultDomain, cookieOpts.HTTPOnly, cookieOpts.Secure, expiration, now, ParseSameSite(cookieOpts.SameSite))
@ -57,7 +57,7 @@ func MakeCookieFromOptions(req *http.Request, name string, value string, cookieO
// GetCookieDomain returns the correct cookie domain given a list of domains
// by checking the X-Fowarded-Host and host header of an an http request
func GetCookieDomain(req *http.Request, cookieDomains []string) string {
host := util.GetRequestHost(req)
host := requestutil.GetRequestHost(req)
for _, domain := range cookieDomains {
if strings.HasSuffix(host, domain) {
return domain

View File

@ -12,7 +12,7 @@ import (
"text/template"
"time"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util"
requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util"
)
// AuthStatus defines the different types of auth logging that occur
@ -197,7 +197,7 @@ func (l *Logger) PrintAuthf(username string, req *http.Request, status AuthStatu
err := l.authTemplate.Execute(l.writer, authLogMessageData{
Client: client,
Host: util.GetRequestHost(req),
Host: requestutil.GetRequestHost(req),
Protocol: req.Proto,
RequestMethod: req.Method,
Timestamp: FormatTimestamp(now),
@ -251,7 +251,7 @@ func (l *Logger) PrintReq(username, upstream string, req *http.Request, url url.
err := l.reqTemplate.Execute(l.writer, reqLogMessageData{
Client: client,
Host: util.GetRequestHost(req),
Host: requestutil.GetRequestHost(req),
Protocol: req.Proto,
RequestDuration: fmt.Sprintf("%0.3f", duration),
RequestMethod: req.Method,

View File

@ -5,6 +5,7 @@ import (
"net/http"
"github.com/justinas/alice"
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
@ -23,7 +24,7 @@ func NewBasicAuthSessionLoader(validator basic.Validator) alice.Constructor {
// If a session was loaded by a previous handler, it will not be replaced.
func loadBasicAuthSession(validator basic.Validator, next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
scope := GetRequestScope(req)
scope := middlewareapi.GetRequestScope(req)
// If scope is nil, this will panic.
// A scope should always be injected before this handler is called.
if scope.Session != nil {

View File

@ -1,7 +1,6 @@
package middleware
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
@ -40,8 +39,7 @@ var _ = Describe("Basic Auth Session Suite", func() {
// Set up the request with the authorization header and a request scope
req := httptest.NewRequest("", "/", nil)
req.Header.Set("Authorization", in.authorizationHeader)
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope)
req = req.WithContext(contextWithScope)
req = middlewareapi.AddRequestScope(req, scope)
rw := httptest.NewRecorder()
@ -57,7 +55,7 @@ var _ = Describe("Basic Auth Session Suite", func() {
// from the scope
var gotSession *sessionsapi.SessionState
handler := NewBasicAuthSessionLoader(validator)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session
gotSession = middlewareapi.GetRequestScope(r).Session
}))
handler.ServeHTTP(rw, req)

View File

@ -5,6 +5,7 @@ import (
"net/http"
"github.com/justinas/alice"
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/header"
)
@ -61,7 +62,7 @@ func newRequestHeaderInjector(headers []options.Header) (alice.Constructor, erro
func injectRequestHeaders(injector header.Injector, next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
scope := GetRequestScope(req)
scope := middlewareapi.GetRequestScope(req)
// If scope is nil, this will panic.
// A scope should always be injected before this handler is called.
@ -92,7 +93,7 @@ func newResponseHeaderInjector(headers []options.Header) (alice.Constructor, err
func injectResponseHeaders(injector header.Injector, next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
scope := GetRequestScope(req)
scope := middlewareapi.GetRequestScope(req)
// If scope is nil, this will panic.
// A scope should always be injected before this handler is called.

View File

@ -1,7 +1,6 @@
package middleware
import (
"context"
"encoding/base64"
"net/http"
"net/http/httptest"
@ -31,8 +30,7 @@ var _ = Describe("Headers Suite", func() {
// Set up the request with a request scope
req := httptest.NewRequest("", "/", nil)
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope)
req = req.WithContext(contextWithScope)
req = middlewareapi.AddRequestScope(req, scope)
req.Header = in.initialHeaders.Clone()
rw := httptest.NewRecorder()
@ -218,8 +216,7 @@ var _ = Describe("Headers Suite", func() {
// Set up the request with a request scope
req := httptest.NewRequest("", "/", nil)
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope)
req = req.WithContext(contextWithScope)
req = middlewareapi.AddRequestScope(req, scope)
rw := httptest.NewRecorder()
for key, values := range in.initialHeaders {

View File

@ -37,7 +37,7 @@ type jwtSessionLoader struct {
// If a session was loaded by a previous handler, it will not be replaced.
func (j *jwtSessionLoader) loadSession(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
scope := GetRequestScope(req)
scope := middlewareapi.GetRequestScope(req)
// If scope is nil, this will panic.
// A scope should always be injected before this handler is called.
if scope.Session != nil {

View File

@ -103,8 +103,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
// Set up the request with the authorization header and a request scope
req := httptest.NewRequest("", "/", nil)
req.Header.Set("Authorization", in.authorizationHeader)
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope)
req = req.WithContext(contextWithScope)
req = middlewareapi.AddRequestScope(req, scope)
rw := httptest.NewRecorder()
@ -116,7 +115,7 @@ Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=`
// from the scope
var gotSession *sessionsapi.SessionState
handler := NewJwtSessionLoader(sessionLoaders)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session
gotSession = middlewareapi.GetRequestScope(r).Session
}))
handler.ServeHTTP(rw, req)

View File

@ -7,7 +7,7 @@ import (
"strings"
"github.com/justinas/alice"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util"
requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util"
)
const httpsScheme = "https"
@ -26,10 +26,11 @@ func NewRedirectToHTTPS(httpsPort string) alice.Constructor {
// to the port from the httpsAddress given.
func redirectToHTTPS(httpsPort string, next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
proto := req.Header.Get("X-Forwarded-Proto")
if strings.EqualFold(proto, httpsScheme) || (req.TLS != nil && proto == "") {
// Only care about the connection to us being HTTPS if the proto is empty,
// otherwise the proto is source of truth
proto := requestutil.GetRequestProto(req)
if strings.EqualFold(proto, httpsScheme) || (req.TLS != nil && proto == req.URL.Scheme) {
// Only care about the connection to us being HTTPS if the proto wasn't
// from a trusted `X-Forwarded-Proto` (proto == req.URL.Scheme).
// Otherwise the proto is source of truth
next.ServeHTTP(rw, req)
return
}
@ -41,7 +42,7 @@ func redirectToHTTPS(httpsPort string, next http.Handler) http.Handler {
// Set the Host in case the targetURL still does not have one
// or it isn't X-Forwarded-Host aware
targetURL.Host = util.GetRequestHost(req)
targetURL.Host = requestutil.GetRequestHost(req)
// Overwrite the port if the original request was to a non-standard port
if targetURL.Port() != "" {

View File

@ -5,6 +5,7 @@ import (
"fmt"
"net/http/httptest"
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
. "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/extensions/table"
. "github.com/onsi/gomega"
@ -21,6 +22,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
requestString string
useTLS bool
headers map[string]string
reverseProxy bool
expectedStatus int
expectedBody string
expectedLocation string
@ -35,6 +37,10 @@ var _ = Describe("RedirectToHTTPS suite", func() {
if in.useTLS {
req.TLS = &tls.ConnectionState{}
}
scope := &middlewareapi.RequestScope{
ReverseProxy: in.reverseProxy,
}
req = middlewareapi.AddRequestScope(req, scope)
rw := httptest.NewRecorder()
@ -52,6 +58,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
requestString: "http://example.com",
useTLS: false,
headers: map[string]string{},
reverseProxy: false,
expectedStatus: 308,
expectedBody: permanentRedirectBody("https://example.com"),
expectedLocation: "https://example.com",
@ -60,6 +67,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
requestString: "https://example.com",
useTLS: true,
headers: map[string]string{},
reverseProxy: false,
expectedStatus: 200,
expectedBody: "test",
}),
@ -69,15 +77,28 @@ var _ = Describe("RedirectToHTTPS suite", func() {
headers: map[string]string{
"X-Forwarded-Proto": "HTTPS",
},
reverseProxy: true,
expectedStatus: 200,
expectedBody: "test",
}),
Entry("without TLS and X-Forwarded-Proto=HTTPS but ReverseProxy not set", &requestTableInput{
requestString: "http://example.com",
useTLS: false,
headers: map[string]string{
"X-Forwarded-Proto": "HTTPS",
},
reverseProxy: false,
expectedStatus: 308,
expectedBody: permanentRedirectBody("https://example.com"),
expectedLocation: "https://example.com",
}),
Entry("with TLS and X-Forwarded-Proto=HTTPS", &requestTableInput{
requestString: "https://example.com",
useTLS: true,
headers: map[string]string{
"X-Forwarded-Proto": "HTTPS",
},
reverseProxy: true,
expectedStatus: 200,
expectedBody: "test",
}),
@ -87,6 +108,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
headers: map[string]string{
"X-Forwarded-Proto": "https",
},
reverseProxy: true,
expectedStatus: 200,
expectedBody: "test",
}),
@ -96,6 +118,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
headers: map[string]string{
"X-Forwarded-Proto": "https",
},
reverseProxy: true,
expectedStatus: 200,
expectedBody: "test",
}),
@ -105,6 +128,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
headers: map[string]string{
"X-Forwarded-Proto": "HTTP",
},
reverseProxy: true,
expectedStatus: 308,
expectedBody: permanentRedirectBody("https://example.com"),
expectedLocation: "https://example.com",
@ -115,6 +139,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
headers: map[string]string{
"X-Forwarded-Proto": "HTTP",
},
reverseProxy: true,
expectedStatus: 308,
expectedBody: permanentRedirectBody("https://example.com"),
expectedLocation: "https://example.com",
@ -125,6 +150,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
headers: map[string]string{
"X-Forwarded-Proto": "http",
},
reverseProxy: true,
expectedStatus: 308,
expectedBody: permanentRedirectBody("https://example.com"),
expectedLocation: "https://example.com",
@ -135,6 +161,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
headers: map[string]string{
"X-Forwarded-Proto": "http",
},
reverseProxy: true,
expectedStatus: 308,
expectedBody: permanentRedirectBody("https://example.com"),
expectedLocation: "https://example.com",
@ -143,6 +170,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
requestString: "http://example.com:8080",
useTLS: false,
headers: map[string]string{},
reverseProxy: false,
expectedStatus: 308,
expectedBody: permanentRedirectBody("https://example.com:8443"),
expectedLocation: "https://example.com:8443",
@ -151,6 +179,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
requestString: "https://example.com:8443",
useTLS: true,
headers: map[string]string{},
reverseProxy: false,
expectedStatus: 200,
expectedBody: "test",
}),
@ -161,6 +190,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
requestString: "/",
useTLS: false,
expectedStatus: 308,
reverseProxy: false,
expectedBody: permanentRedirectBody("https://example.com/"),
expectedLocation: "https://example.com/",
}),
@ -171,6 +201,7 @@ var _ = Describe("RedirectToHTTPS suite", func() {
"X-Forwarded-Proto": "HTTP",
"X-Forwarded-Host": "external.example.com",
},
reverseProxy: true,
expectedStatus: 308,
expectedBody: permanentRedirectBody("https://external.example.com"),
expectedLocation: "https://external.example.com",

View File

@ -1,39 +1,20 @@
package middleware
import (
"context"
"net/http"
"github.com/justinas/alice"
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
)
type scopeKey string
// requestScopeKey uses a typed string to reduce likelihood of clasing
// with other context keys
const requestScopeKey scopeKey = "request-scope"
func NewScope() alice.Constructor {
return addScope
}
// addScope injects a new request scope into the request context.
func addScope(next http.Handler) http.Handler {
func NewScope(reverseProxy bool) alice.Constructor {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
scope := &middlewareapi.RequestScope{}
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope)
requestWithScope := req.WithContext(contextWithScope)
next.ServeHTTP(rw, requestWithScope)
})
}
// GetRequestScope returns the current request scope from the given request
func GetRequestScope(req *http.Request) *middlewareapi.RequestScope {
scope := req.Context().Value(requestScopeKey)
if scope == nil {
return nil
scope := &middlewareapi.RequestScope{
ReverseProxy: reverseProxy,
}
req = middlewareapi.AddRequestScope(req, scope)
next.ServeHTTP(rw, req)
})
}
return scope.(*middlewareapi.RequestScope)
}

View File

@ -1,7 +1,6 @@
package middleware
import (
"context"
"net/http"
"net/http/httptest"
@ -21,8 +20,11 @@ var _ = Describe("Scope Suite", func() {
Expect(err).ToNot(HaveOccurred())
rw = httptest.NewRecorder()
})
handler := NewScope()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Context("ReverseProxy is false", func() {
BeforeEach(func() {
handler := NewScope(false)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextRequest = r
w.WriteHeader(200)
}))
@ -30,64 +32,37 @@ var _ = Describe("Scope Suite", func() {
})
It("does not add a scope to the original request", func() {
Expect(request.Context().Value(requestScopeKey)).To(BeNil())
Expect(request.Context().Value(middlewareapi.RequestScopeKey)).To(BeNil())
})
It("cannot load a scope from the original request using GetRequestScope", func() {
Expect(GetRequestScope(request)).To(BeNil())
Expect(middlewareapi.GetRequestScope(request)).To(BeNil())
})
It("adds a scope to the request for the next handler", func() {
Expect(nextRequest.Context().Value(requestScopeKey)).ToNot(BeNil())
Expect(nextRequest.Context().Value(middlewareapi.RequestScopeKey)).ToNot(BeNil())
})
It("can load a scope from the next handler's request using GetRequestScope", func() {
Expect(GetRequestScope(nextRequest)).ToNot(BeNil())
scope := middlewareapi.GetRequestScope(nextRequest)
Expect(scope).ToNot(BeNil())
Expect(scope.ReverseProxy).To(BeFalse())
})
})
Context("GetRequestScope", func() {
var request *http.Request
Context("ReverseProxy is true", func() {
BeforeEach(func() {
var err error
request, err = http.NewRequest("", "http://127.0.0.1/", nil)
Expect(err).ToNot(HaveOccurred())
handler := NewScope(true)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextRequest = r
w.WriteHeader(200)
}))
handler.ServeHTTP(rw, request)
})
Context("with a scope", func() {
var scope *middlewareapi.RequestScope
BeforeEach(func() {
scope = &middlewareapi.RequestScope{}
contextWithScope := context.WithValue(request.Context(), requestScopeKey, scope)
request = request.WithContext(contextWithScope)
})
It("returns the scope", func() {
s := GetRequestScope(request)
Expect(s).ToNot(BeNil())
Expect(s).To(Equal(scope))
})
Context("if the scope is then modified", func() {
BeforeEach(func() {
Expect(scope.SaveSession).To(BeFalse())
scope.SaveSession = true
})
It("returns the updated session", func() {
s := GetRequestScope(request)
Expect(s).ToNot(BeNil())
Expect(s).To(Equal(scope))
Expect(s.SaveSession).To(BeTrue())
})
})
})
Context("without a scope", func() {
It("returns nil", func() {
Expect(GetRequestScope(request)).To(BeNil())
It("return a scope where the ReverseProxy field is true", func() {
scope := middlewareapi.GetRequestScope(nextRequest)
Expect(scope).ToNot(BeNil())
Expect(scope.ReverseProxy).To(BeTrue())
})
})
})

View File

@ -8,6 +8,7 @@ import (
"time"
"github.com/justinas/alice"
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
)
@ -59,7 +60,7 @@ type storedSessionLoader struct {
// If a session was loader by a previous handler, it will not be replaced.
func (s *storedSessionLoader) loadSession(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
scope := GetRequestScope(req)
scope := middlewareapi.GetRequestScope(req)
// If scope is nil, this will panic.
// A scope should always be injected before this handler is called.
if scope.Session != nil {

View File

@ -104,8 +104,7 @@ var _ = Describe("Stored Session Suite", func() {
// Set up the request with the request headesr and a request scope
req := httptest.NewRequest("", "/", nil)
req.Header = in.requestHeaders
contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope)
req = req.WithContext(contextWithScope)
req = middlewareapi.AddRequestScope(req, scope)
rw := httptest.NewRecorder()
@ -120,7 +119,7 @@ var _ = Describe("Stored Session Suite", func() {
// from the scope
var gotSession *sessionsapi.SessionState
handler := NewStoredSessionLoader(opts)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session
gotSession = middlewareapi.GetRequestScope(r).Session
}))
handler.ServeHTTP(rw, req)

48
pkg/requests/util/util.go Normal file
View File

@ -0,0 +1,48 @@
package util
import (
"net/http"
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
)
// GetRequestProto returns the request scheme or X-Forwarded-Proto if present
// and the request is proxied.
func GetRequestProto(req *http.Request) string {
proto := req.Header.Get("X-Forwarded-Proto")
if !IsProxied(req) || proto == "" {
proto = req.URL.Scheme
}
return proto
}
// GetRequestHost returns the request host header or X-Forwarded-Host if
// present and the request is proxied.
func GetRequestHost(req *http.Request) string {
host := req.Header.Get("X-Forwarded-Host")
if !IsProxied(req) || host == "" {
host = req.Host
}
return host
}
// GetRequestURI return the request URI or X-Forwarded-Uri if present and the
// request is proxied.
func GetRequestURI(req *http.Request) string {
uri := req.Header.Get("X-Forwarded-Uri")
if !IsProxied(req) || uri == "" {
// Use RequestURI to preserve ?query
uri = req.URL.RequestURI()
}
return uri
}
// IsProxied determines if a request was from a proxy based on the RequestScope
// ReverseProxy tracker.
func IsProxied(req *http.Request) bool {
scope := middlewareapi.GetRequestScope(req)
if scope == nil {
return false
}
return scope.ReverseProxy
}

View File

@ -0,0 +1,19 @@
package util_test
import (
"testing"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
// TestRequestUtilSuite and related tests are in a *_test package
// to prevent circular imports with the `logger` package which uses
// this functionality
func TestRequestUtilSuite(t *testing.T) {
logger.SetOutput(GinkgoWriter)
RegisterFailHandler(Fail)
RunSpecs(t, "Request Utils")
}

View File

@ -0,0 +1,131 @@
package util_test
import (
"fmt"
"net/http"
"net/http/httptest"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Util Suite", func() {
const (
proto = "http"
host = "www.oauth2proxy.test"
uri = "/test/endpoint"
)
var req *http.Request
BeforeEach(func() {
req = httptest.NewRequest(
http.MethodGet,
fmt.Sprintf("%s://%s%s", proto, host, uri),
nil,
)
})
Context("GetRequestHost", func() {
Context("IsProxied is false", func() {
BeforeEach(func() {
req = middleware.AddRequestScope(req, &middleware.RequestScope{})
})
It("returns the host", func() {
Expect(util.GetRequestHost(req)).To(Equal(host))
})
It("ignores X-Forwarded-Host and returns the host", func() {
req.Header.Add("X-Forwarded-Host", "external.oauth2proxy.text")
Expect(util.GetRequestHost(req)).To(Equal(host))
})
})
Context("IsProxied is true", func() {
BeforeEach(func() {
req = middleware.AddRequestScope(req, &middleware.RequestScope{
ReverseProxy: true,
})
})
It("returns the host if X-Forwarded-Host is not present", func() {
Expect(util.GetRequestHost(req)).To(Equal(host))
})
It("returns the X-Forwarded-Host when present", func() {
req.Header.Add("X-Forwarded-Host", "external.oauth2proxy.text")
Expect(util.GetRequestHost(req)).To(Equal("external.oauth2proxy.text"))
})
})
})
Context("GetRequestProto", func() {
Context("IsProxied is false", func() {
BeforeEach(func() {
req = middleware.AddRequestScope(req, &middleware.RequestScope{})
})
It("returns the scheme", func() {
Expect(util.GetRequestProto(req)).To(Equal(proto))
})
It("ignores X-Forwarded-Proto and returns the scheme", func() {
req.Header.Add("X-Forwarded-Proto", "https")
Expect(util.GetRequestProto(req)).To(Equal(proto))
})
})
Context("IsProxied is true", func() {
BeforeEach(func() {
req = middleware.AddRequestScope(req, &middleware.RequestScope{
ReverseProxy: true,
})
})
It("returns the scheme if X-Forwarded-Proto is not present", func() {
Expect(util.GetRequestProto(req)).To(Equal(proto))
})
It("returns the X-Forwarded-Proto when present", func() {
req.Header.Add("X-Forwarded-Proto", "https")
Expect(util.GetRequestProto(req)).To(Equal("https"))
})
})
})
Context("GetRequestURI", func() {
Context("IsProxied is false", func() {
BeforeEach(func() {
req = middleware.AddRequestScope(req, &middleware.RequestScope{})
})
It("returns the URI", func() {
Expect(util.GetRequestURI(req)).To(Equal(uri))
})
It("ignores X-Forwarded-Uri and returns the URI", func() {
req.Header.Add("X-Forwarded-Uri", "/some/other/path")
Expect(util.GetRequestURI(req)).To(Equal(uri))
})
})
Context("IsProxied is true", func() {
BeforeEach(func() {
req = middleware.AddRequestScope(req, &middleware.RequestScope{
ReverseProxy: true,
})
})
It("returns the URI if X-Forwarded-Uri is not present", func() {
Expect(util.GetRequestURI(req)).To(Equal(uri))
})
It("returns the X-Forwarded-Uri when present", func() {
req.Header.Add("X-Forwarded-Uri", "/some/other/path")
Expect(util.GetRequestURI(req)).To(Equal("/some/other/path"))
})
})
})
})

View File

@ -5,7 +5,6 @@ import (
"fmt"
"net/http"
"regexp"
"strings"
"time"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
@ -220,12 +219,12 @@ func loadCookie(req *http.Request, cookieName string) (*http.Cookie, error) {
if len(cookies) == 0 {
return nil, fmt.Errorf("could not find cookie %s", cookieName)
}
return joinCookies(cookies)
return joinCookies(cookies, cookieName)
}
// joinCookies takes a slice of cookies from the request and reconstructs the
// full session cookie
func joinCookies(cookies []*http.Cookie) (*http.Cookie, error) {
func joinCookies(cookies []*http.Cookie, cookieName string) (*http.Cookie, error) {
if len(cookies) == 0 {
return nil, fmt.Errorf("list of cookies must be > 0")
}
@ -236,7 +235,7 @@ func joinCookies(cookies []*http.Cookie) (*http.Cookie, error) {
for i := 1; i < len(cookies); i++ {
c.Value += cookies[i].Value
}
c.Name = strings.TrimRight(c.Name, "_0")
c.Name = cookieName
return c, nil
}

View File

@ -154,9 +154,58 @@ func Test_splitCookie_joinCookies(t *testing.T) {
Value: value,
}
splitCookies := splitCookie(cookie)
joinedCookie, err := joinCookies(splitCookies)
joinedCookie, err := joinCookies(splitCookies, cookie.Name)
assert.NoError(t, err)
assert.Equal(t, *cookie, *joinedCookie)
})
}
}
func Test_joinCookies_withUnderlineSuffix(t *testing.T) {
testCases := map[string]struct {
CookieName string
SplitOrder []int
}{
"Ascending order split with \"_\" suffix": {
CookieName: "_cookie_name_",
SplitOrder: []int{0, 1, 2, 3, 4},
},
"Descending order split with \"_\" suffix": {
CookieName: "_cookie_name_",
SplitOrder: []int{4, 3, 2, 1, 0},
},
"Arbitrary order split with \"_\" suffix": {
CookieName: "_cookie_name_",
SplitOrder: []int{3, 1, 2, 0, 4},
},
"Arbitrary order split with \"_0\" suffix": {
CookieName: "_cookie_name_0",
SplitOrder: []int{1, 3, 0, 2, 4},
},
"Arbitrary order split with \"_1\" suffix": {
CookieName: "_cookie_name_1",
SplitOrder: []int{4, 1, 3, 0, 2},
},
"Arbitrary order split with \"__\" suffix": {
CookieName: "_cookie_name__",
SplitOrder: []int{1, 0, 4, 3, 2},
},
}
for testName, testCase := range testCases {
t.Run(testName, func(t *testing.T) {
cookieName := testCase.CookieName
var splitCookies []*http.Cookie
for _, splitSuffix := range testCase.SplitOrder {
cookie := &http.Cookie{
Name: splitCookieName(cookieName, splitSuffix),
Value: strings.Repeat("v", 1000),
}
splitCookies = append(splitCookies, cookie)
}
joinedCookie, err := joinCookies(splitCookies, cookieName)
assert.NoError(t, err)
assert.Equal(t, cookieName, joinedCookie.Name)
})
}
}

View File

@ -4,7 +4,6 @@ import (
"crypto/x509"
"fmt"
"io/ioutil"
"net/http"
)
func GetCertPool(paths []string) (*x509.CertPool, error) {
@ -24,12 +23,3 @@ func GetCertPool(paths []string) (*x509.CertPool, error) {
}
return pool, nil
}
// GetRequestHost return the request host header or X-Forwarded-Host if present
func GetRequestHost(req *http.Request) string {
host := req.Header.Get("X-Forwarded-Host")
if host == "" {
host = req.Host
}
return host
}

View File

@ -4,11 +4,9 @@ import (
"crypto/x509/pkix"
"encoding/asn1"
"io/ioutil"
"net/http/httptest"
"os"
"testing"
. "github.com/onsi/gomega"
"github.com/stretchr/testify/assert"
)
@ -97,16 +95,3 @@ func TestGetCertPool(t *testing.T) {
expectedSubjects := []string{testCA1Subj, testCA2Subj}
assert.Equal(t, expectedSubjects, got)
}
func TestGetRequestHost(t *testing.T) {
g := NewWithT(t)
req := httptest.NewRequest("GET", "https://example.com", nil)
host := GetRequestHost(req)
g.Expect(host).To(Equal("example.com"))
proxyReq := httptest.NewRequest("GET", "http://internal.example.com", nil)
proxyReq.Header.Add("X-Forwarded-Host", "external.example.com")
extHost := GetRequestHost(proxyReq)
g.Expect(extHost).To(Equal("external.example.com"))
}

View File

@ -233,9 +233,19 @@ func parseProviderInfo(o *options.Options, msgs []string) []string {
p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs)
p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs)
// Make the OIDC Verifier accessible to all providers that can support it
// Make the OIDC options available to all providers that support it
p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail
p.EmailClaim = o.OIDCEmailClaim
p.GroupsClaim = o.OIDCGroupsClaim
p.Verifier = o.GetOIDCVerifier()
// TODO (@NickMeves) - Remove This
// Backwards Compatibility for Deprecated UserIDClaim option
if o.OIDCEmailClaim == providers.OIDCEmailClaim &&
o.UserIDClaim != providers.OIDCEmailClaim {
p.EmailClaim = o.UserIDClaim
}
p.SetAllowedGroups(o.AllowedGroups)
provider := providers.New(o.ProviderType, p)
@ -253,7 +263,10 @@ func parseProviderInfo(o *options.Options, msgs []string) []string {
p.SetRepo(o.GitHubRepo, o.GitHubToken)
p.SetUsers(o.GitHubUsers)
case *providers.KeycloakProvider:
p.SetGroup(o.KeycloakGroup)
// Backwards compatibility with `--keycloak-group` option
if len(o.KeycloakGroups) > 0 {
p.SetAllowedGroups(o.KeycloakGroups)
}
case *providers.GoogleProvider:
if o.GoogleServiceAccountJSON != "" {
file, err := os.Open(o.GoogleServiceAccountJSON)
@ -273,14 +286,10 @@ func parseProviderInfo(o *options.Options, msgs []string) []string {
p.SetTeam(o.BitbucketTeam)
p.SetRepository(o.BitbucketRepository)
case *providers.OIDCProvider:
p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail
p.UserIDClaim = o.UserIDClaim
p.GroupsClaim = o.OIDCGroupsClaim
if p.Verifier == nil {
msgs = append(msgs, "oidc provider requires an oidc issuer URL")
}
case *providers.GitLabProvider:
p.AllowUnverifiedEmail = o.InsecureOIDCAllowUnverifiedEmail
p.Groups = o.GitLabGroup
err := p.AddProjects(o.GitlabProjects)
if err != nil {

View File

@ -20,8 +20,6 @@ type GitLabProvider struct {
Groups []string
Projects []*GitlabProject
AllowUnverifiedEmail bool
}
// GitlabProject represents a Gitlab project constraint entity
@ -103,7 +101,7 @@ func (p *GitLabProvider) Redeem(ctx context.Context, redirectURL, code string) (
if err != nil {
return nil, fmt.Errorf("token exchange: %v", err)
}
s, err = p.createSessionState(ctx, token)
s, err = p.createSession(ctx, token)
if err != nil {
return nil, fmt.Errorf("unable to update session: %v", err)
}
@ -162,7 +160,7 @@ func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.Ses
if err != nil {
return fmt.Errorf("failed to get token: %v", err)
}
newSession, err := p.createSessionState(ctx, token)
newSession, err := p.createSession(ctx, token)
if err != nil {
return fmt.Errorf("unable to update session: %v", err)
}
@ -255,22 +253,21 @@ func (p *GitLabProvider) AddProjects(projects []string) error {
return nil
}
func (p *GitLabProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*sessions.SessionState, error) {
rawIDToken, ok := token.Extra("id_token").(string)
if !ok {
return nil, fmt.Errorf("token response did not contain an id_token")
}
// Parse and verify ID Token payload.
idToken, err := p.Verifier.Verify(ctx, rawIDToken)
func (p *GitLabProvider) createSession(ctx context.Context, token *oauth2.Token) (*sessions.SessionState, error) {
idToken, err := p.verifyIDToken(ctx, token)
if err != nil {
switch err {
case ErrMissingIDToken:
return nil, fmt.Errorf("token response did not contain an id_token")
default:
return nil, fmt.Errorf("could not verify id_token: %v", err)
}
}
created := time.Now()
return &sessions.SessionState{
AccessToken: token.AccessToken,
IDToken: rawIDToken,
IDToken: getIDToken(token),
RefreshToken: token.RefreshToken,
CreatedAt: &created,
ExpiresOn: &idToken.Expiry,

View File

@ -2,6 +2,7 @@ package providers
import (
"context"
"fmt"
"net/url"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
@ -11,7 +12,6 @@ import (
type KeycloakProvider struct {
*ProviderData
Group string
}
var _ Provider = (*KeycloakProvider)(nil)
@ -47,6 +47,7 @@ var (
}
)
// NewKeycloakProvider creates a KeyCloakProvider using the passed ProviderData
func NewKeycloakProvider(p *ProviderData) *KeycloakProvider {
p.setProviderDefaults(providerDefaults{
name: keycloakProviderName,
@ -59,41 +60,39 @@ func NewKeycloakProvider(p *ProviderData) *KeycloakProvider {
return &KeycloakProvider{ProviderData: p}
}
func (p *KeycloakProvider) SetGroup(group string) {
p.Group = group
}
// EnrichSession uses the Keycloak userinfo endpoint to populate the session's
// email and groups.
func (p *KeycloakProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error {
// Fallback to ValidateURL if ProfileURL not set for legacy compatibility
profileURL := p.ValidateURL.String()
if p.ProfileURL.String() != "" {
profileURL = p.ProfileURL.String()
}
func (p *KeycloakProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) {
json, err := requests.New(p.ValidateURL.String()).
json, err := requests.New(profileURL).
WithContext(ctx).
SetHeader("Authorization", "Bearer "+s.AccessToken).
Do().
UnmarshalJSON()
if err != nil {
logger.Errorf("failed making request %s", err)
return "", err
logger.Errorf("failed making request %v", err)
return err
}
if p.Group != "" {
var groups, err = json.Get("groups").Array()
groups, err := json.Get("groups").StringArray()
if err == nil {
for _, group := range groups {
if group != "" {
s.Groups = append(s.Groups, group)
}
}
}
email, err := json.Get("email").String()
if err != nil {
logger.Printf("groups not found %s", err)
return "", err
return fmt.Errorf("unable to extract email from userinfo endpoint: %v", err)
}
s.Email = email
var found = false
for i := range groups {
if groups[i].(string) == p.Group {
found = true
break
}
}
if !found {
logger.Printf("group not found, access denied")
return "", nil
}
}
return json.Get("email").String()
return nil
}

View File

@ -2,17 +2,24 @@ package providers
import (
"context"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
. "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/extensions/table"
. "github.com/onsi/gomega"
"github.com/stretchr/testify/assert"
)
func testKeycloakProvider(hostname, group string) *KeycloakProvider {
const (
keycloakAccessToken = "eyJKeycloak.eyJAccess.Token"
keycloakUserinfoPath = "/api/v3/user"
)
func testKeycloakProvider(backend *httptest.Server) (*KeycloakProvider, error) {
p := NewKeycloakProvider(
&ProviderData{
ProviderName: "",
@ -22,63 +29,35 @@ func testKeycloakProvider(hostname, group string) *KeycloakProvider {
ValidateURL: &url.URL{},
Scope: ""})
if group != "" {
p.SetGroup(group)
if backend != nil {
bURL, err := url.Parse(backend.URL)
if err != nil {
return nil, err
}
hostname := bURL.Host
if hostname != "" {
updateURL(p.Data().LoginURL, hostname)
updateURL(p.Data().RedeemURL, hostname)
updateURL(p.Data().ProfileURL, hostname)
updateURL(p.Data().ValidateURL, hostname)
}
return p
return p, nil
}
func testKeycloakBackend(payload string) *httptest.Server {
path := "/api/v3/user"
return httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
url := r.URL
if url.Path != path {
w.WriteHeader(404)
} else if !IsAuthorizedInHeader(r.Header) {
w.WriteHeader(403)
} else {
w.WriteHeader(200)
w.Write([]byte(payload))
}
}))
}
func TestKeycloakProviderDefaults(t *testing.T) {
p := testKeycloakProvider("", "")
assert.NotEqual(t, nil, p)
assert.Equal(t, "Keycloak", p.Data().ProviderName)
assert.Equal(t, "https://keycloak.org/oauth/authorize",
p.Data().LoginURL.String())
assert.Equal(t, "https://keycloak.org/oauth/token",
p.Data().RedeemURL.String())
assert.Equal(t, "https://keycloak.org/api/v3/user",
p.Data().ValidateURL.String())
assert.Equal(t, "api", p.Data().Scope)
}
func TestNewKeycloakProvider(t *testing.T) {
g := NewWithT(t)
// Test that defaults are set when calling for a new provider with nothing set
var _ = Describe("Keycloak Provider Tests", func() {
Context("New Provider Init", func() {
It("uses defaults", func() {
providerData := NewKeycloakProvider(&ProviderData{}).Data()
g.Expect(providerData.ProviderName).To(Equal("Keycloak"))
g.Expect(providerData.LoginURL.String()).To(Equal("https://keycloak.org/oauth/authorize"))
g.Expect(providerData.RedeemURL.String()).To(Equal("https://keycloak.org/oauth/token"))
g.Expect(providerData.ProfileURL.String()).To(Equal(""))
g.Expect(providerData.ValidateURL.String()).To(Equal("https://keycloak.org/api/v3/user"))
g.Expect(providerData.Scope).To(Equal("api"))
}
Expect(providerData.ProviderName).To(Equal("Keycloak"))
Expect(providerData.LoginURL.String()).To(Equal("https://keycloak.org/oauth/authorize"))
Expect(providerData.RedeemURL.String()).To(Equal("https://keycloak.org/oauth/token"))
Expect(providerData.ProfileURL.String()).To(Equal(""))
Expect(providerData.ValidateURL.String()).To(Equal("https://keycloak.org/api/v3/user"))
Expect(providerData.Scope).To(Equal("api"))
})
func TestKeycloakProviderOverrides(t *testing.T) {
It("overrides defaults", func() {
p := NewKeycloakProvider(
&ProviderData{
LoginURL: &url.URL{
@ -89,75 +68,143 @@ func TestKeycloakProviderOverrides(t *testing.T) {
Scheme: "https",
Host: "example.com",
Path: "/oauth/token"},
ProfileURL: &url.URL{
Scheme: "https",
Host: "example.com",
Path: "/api/v3/user"},
ValidateURL: &url.URL{
Scheme: "https",
Host: "example.com",
Path: "/api/v3/user"},
Scope: "profile"})
assert.NotEqual(t, nil, p)
assert.Equal(t, "Keycloak", p.Data().ProviderName)
assert.Equal(t, "https://example.com/oauth/auth",
p.Data().LoginURL.String())
assert.Equal(t, "https://example.com/oauth/token",
p.Data().RedeemURL.String())
assert.Equal(t, "https://example.com/api/v3/user",
p.Data().ValidateURL.String())
assert.Equal(t, "profile", p.Data().Scope)
}
providerData := p.Data()
func TestKeycloakProviderGetEmailAddress(t *testing.T) {
b := testKeycloakBackend("{\"email\": \"michael.bland@gsa.gov\"}")
defer b.Close()
Expect(providerData.ProviderName).To(Equal("Keycloak"))
Expect(providerData.LoginURL.String()).To(Equal("https://example.com/oauth/auth"))
Expect(providerData.RedeemURL.String()).To(Equal("https://example.com/oauth/token"))
Expect(providerData.ProfileURL.String()).To(Equal("https://example.com/api/v3/user"))
Expect(providerData.ValidateURL.String()).To(Equal("https://example.com/api/v3/user"))
Expect(providerData.Scope).To(Equal("profile"))
})
})
bURL, _ := url.Parse(b.URL)
p := testKeycloakProvider(bURL.Host, "")
Context("EnrichSession", func() {
type enrichSessionTableInput struct {
backendHandler http.HandlerFunc
expectedError error
expectedEmail string
expectedGroups []string
}
session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(context.Background(), session)
assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email)
}
DescribeTable("should return expected results",
func(in enrichSessionTableInput) {
backend := httptest.NewServer(in.backendHandler)
p, err := testKeycloakProvider(backend)
Expect(err).To(BeNil())
func TestKeycloakProviderGetEmailAddressAndGroup(t *testing.T) {
b := testKeycloakBackend("{\"email\": \"michael.bland@gsa.gov\", \"groups\": [\"test-grp1\", \"test-grp2\"]}")
defer b.Close()
p.ProfileURL, err = url.Parse(
fmt.Sprintf("%s%s", backend.URL, keycloakUserinfoPath),
)
Expect(err).To(BeNil())
bURL, _ := url.Parse(b.URL)
p := testKeycloakProvider(bURL.Host, "test-grp1")
session := &sessions.SessionState{AccessToken: keycloakAccessToken}
err = p.EnrichSession(context.Background(), session)
session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(context.Background(), session)
assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email)
}
if in.expectedError != nil {
Expect(err).To(Equal(in.expectedError))
} else {
Expect(err).To(BeNil())
}
// Note that trying to trigger the "failed building request" case is not
// practical, since the only way it can fail is if the URL fails to parse.
func TestKeycloakProviderGetEmailAddressFailedRequest(t *testing.T) {
b := testKeycloakBackend("unused payload")
defer b.Close()
Expect(session.Email).To(Equal(in.expectedEmail))
bURL, _ := url.Parse(b.URL)
p := testKeycloakProvider(bURL.Host, "")
// We'll trigger a request failure by using an unexpected access
// token. Alternatively, we could allow the parsing of the payload as
// JSON to fail.
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
email, err := p.GetEmailAddress(context.Background(), session)
assert.NotEqual(t, nil, err)
assert.Equal(t, "", email)
}
func TestKeycloakProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
b := testKeycloakBackend("{\"foo\": \"bar\"}")
defer b.Close()
bURL, _ := url.Parse(b.URL)
p := testKeycloakProvider(bURL.Host, "")
session := CreateAuthorizedSession()
email, err := p.GetEmailAddress(context.Background(), session)
assert.NotEqual(t, nil, err)
assert.Equal(t, "", email)
}
if in.expectedGroups != nil {
Expect(session.Groups).To(Equal(in.expectedGroups))
} else {
Expect(session.Groups).To(BeNil())
}
},
Entry("email and multiple groups", enrichSessionTableInput{
backendHandler: func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(200)
_, err := w.Write([]byte(`
{
"email": "michael.bland@gsa.gov",
"groups": [
"test-grp1",
"test-grp2"
]
}
`))
if err != nil {
panic(err)
}
},
expectedError: nil,
expectedEmail: "michael.bland@gsa.gov",
expectedGroups: []string{"test-grp1", "test-grp2"},
}),
Entry("email and single group", enrichSessionTableInput{
backendHandler: func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(200)
_, err := w.Write([]byte(`
{
"email": "michael.bland@gsa.gov",
"groups": ["test-grp1"]
}
`))
if err != nil {
panic(err)
}
},
expectedError: nil,
expectedEmail: "michael.bland@gsa.gov",
expectedGroups: []string{"test-grp1"},
}),
Entry("email and no groups", enrichSessionTableInput{
backendHandler: func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(200)
_, err := w.Write([]byte(`
{
"email": "michael.bland@gsa.gov"
}
`))
if err != nil {
panic(err)
}
},
expectedError: nil,
expectedEmail: "michael.bland@gsa.gov",
expectedGroups: nil,
}),
Entry("missing email", enrichSessionTableInput{
backendHandler: func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(200)
_, err := w.Write([]byte(`
{
"groups": [
"test-grp1",
"test-grp2"
]
}
`))
if err != nil {
panic(err)
}
},
expectedError: errors.New(
"unable to extract email from userinfo endpoint: type assertion to string failed"),
expectedEmail: "",
expectedGroups: []string{"test-grp1", "test-grp2"},
}),
Entry("request failure", enrichSessionTableInput{
backendHandler: func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(500)
},
expectedError: errors.New(`unexpected status "500": `),
expectedEmail: "",
expectedGroups: nil,
}),
)
})
})

View File

@ -2,29 +2,20 @@ package providers
import (
"context"
"encoding/json"
"errors"
"fmt"
"reflect"
"strings"
"time"
oidc "github.com/coreos/go-oidc"
"golang.org/x/oauth2"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
"golang.org/x/oauth2"
)
const emailClaim = "email"
// OIDCProvider represents an OIDC based Identity Provider
type OIDCProvider struct {
*ProviderData
AllowUnverifiedEmail bool
UserIDClaim string
GroupsClaim string
}
// NewOIDCProvider initiates a new OIDCProvider
@ -36,10 +27,10 @@ func NewOIDCProvider(p *ProviderData) *OIDCProvider {
var _ Provider = (*OIDCProvider)(nil)
// Redeem exchanges the OAuth2 authentication token for an ID token
func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (s *sessions.SessionState, err error) {
func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
clientSecret, err := p.GetClientSecret()
if err != nil {
return
return nil, err
}
c := oauth2.Config{
@ -52,23 +43,74 @@ func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (s
}
token, err := c.Exchange(ctx, code)
if err != nil {
return nil, fmt.Errorf("token exchange: %v", err)
return nil, fmt.Errorf("token exchange failed: %v", err)
}
// in the initial exchange the id token is mandatory
idToken, err := p.findVerifiedIDToken(ctx, token)
return p.createSession(ctx, token, false)
}
// EnrichSession is called after Redeem to allow providers to enrich session fields
// such as User, Email, Groups with provider specific API calls.
func (p *OIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error {
if p.ProfileURL.String() == "" {
if s.Email == "" {
return errors.New("id_token did not contain an email and profileURL is not defined")
}
return nil
}
// Try to get missing emails or groups from a profileURL
if s.Email == "" || s.Groups == nil {
err := p.enrichFromProfileURL(ctx, s)
if err != nil {
return nil, fmt.Errorf("could not verify id_token: %v", err)
} else if idToken == nil {
return nil, fmt.Errorf("token response did not contain an id_token")
logger.Errorf("Warning: Profile URL request failed: %v", err)
}
}
s, err = p.createSessionState(ctx, token, idToken)
// If a mandatory email wasn't set, error at this point.
if s.Email == "" {
return errors.New("neither the id_token nor the profileURL set an email")
}
return nil
}
// enrichFromProfileURL enriches a session's Email & Groups via the JSON response of
// an OIDC profile URL
func (p *OIDCProvider) enrichFromProfileURL(ctx context.Context, s *sessions.SessionState) error {
respJSON, err := requests.New(p.ProfileURL.String()).
WithContext(ctx).
WithHeaders(makeOIDCHeader(s.AccessToken)).
Do().
UnmarshalJSON()
if err != nil {
return nil, fmt.Errorf("unable to update session: %v", err)
return err
}
return
email, err := respJSON.Get(p.EmailClaim).String()
if err == nil && s.Email == "" {
s.Email = email
}
if len(s.Groups) > 0 {
return nil
}
for _, group := range coerceArray(respJSON, p.GroupsClaim) {
formatted, err := formatGroup(group)
if err != nil {
logger.Errorf("Warning: unable to format group of type %s with error %s",
reflect.TypeOf(group), err)
continue
}
s.Groups = append(s.Groups, formatted)
}
return nil
}
// ValidateSession checks that the session's IDToken is still valid
func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
_, err := p.Verifier.Verify(ctx, s.IDToken)
return err == nil
}
// RefreshSessionIfNeeded checks if the session has expired and uses the
@ -83,14 +125,16 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.S
return false, fmt.Errorf("unable to redeem refresh token: %v", err)
}
fmt.Printf("refreshed access token %s (expired on %s)\n", s, s.ExpiresOn)
logger.Printf("refreshed session: %s", s)
return true, nil
}
func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) {
// redeemRefreshToken uses a RefreshToken with the RedeemURL to refresh the
// Access Token and (probably) the ID Token.
func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error {
clientSecret, err := p.GetClientSecret()
if err != nil {
return
return err
}
c := oauth2.Config{
@ -109,19 +153,14 @@ func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sessi
return fmt.Errorf("failed to get token: %v", err)
}
// in the token refresh response the id_token is optional
idToken, err := p.findVerifiedIDToken(ctx, token)
if err != nil {
return fmt.Errorf("unable to extract id_token from response: %v", err)
}
newSession, err := p.createSessionState(ctx, token, idToken)
newSession, err := p.createSession(ctx, token, true)
if err != nil {
return fmt.Errorf("unable create new session state from response: %v", err)
}
// It's possible that if the refresh token isn't in the token response the session will not contain an id token
// if it doesn't it's probably better to retain the old one
// It's possible that if the refresh token isn't in the token response the
// session will not contain an id token.
// If it doesn't it's probably better to retain the old one
if newSession.IDToken != "" {
s.IDToken = newSession.IDToken
s.Email = newSession.Email
@ -135,193 +174,62 @@ func (p *OIDCProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sessi
s.CreatedAt = newSession.CreatedAt
s.ExpiresOn = newSession.ExpiresOn
return
}
func (p *OIDCProvider) findVerifiedIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) {
getIDToken := func() (string, bool) {
rawIDToken, _ := token.Extra("id_token").(string)
return rawIDToken, len(strings.TrimSpace(rawIDToken)) > 0
}
if rawIDToken, present := getIDToken(); present {
verifiedIDToken, err := p.Verifier.Verify(ctx, rawIDToken)
return verifiedIDToken, err
}
return nil, nil
}
func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token, idToken *oidc.IDToken) (*sessions.SessionState, error) {
var newSession *sessions.SessionState
if idToken == nil {
newSession = &sessions.SessionState{}
} else {
var err error
newSession, err = p.createSessionStateInternal(ctx, idToken, token)
if err != nil {
return nil, err
}
}
created := time.Now()
newSession.AccessToken = token.AccessToken
newSession.RefreshToken = token.RefreshToken
newSession.CreatedAt = &created
newSession.ExpiresOn = &token.Expiry
return newSession, nil
return nil
}
// CreateSessionFromToken converts Bearer IDTokens into sessions
func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) {
idToken, err := p.Verifier.Verify(ctx, token)
if err != nil {
return nil, err
}
newSession, err := p.createSessionStateInternal(ctx, idToken, nil)
ss, err := p.buildSessionFromClaims(idToken)
if err != nil {
return nil, err
}
newSession.AccessToken = token
newSession.IDToken = token
newSession.RefreshToken = ""
newSession.ExpiresOn = &idToken.Expiry
return newSession, nil
}
func (p *OIDCProvider) createSessionStateInternal(ctx context.Context, idToken *oidc.IDToken, token *oauth2.Token) (*sessions.SessionState, error) {
newSession := &sessions.SessionState{}
if idToken == nil {
return newSession, nil
// Allow empty Email in Bearer case since we can't hit the ProfileURL
if ss.Email == "" {
ss.Email = ss.User
}
claims, err := p.findClaimsFromIDToken(ctx, idToken, token)
ss.AccessToken = token
ss.IDToken = token
ss.RefreshToken = ""
ss.ExpiresOn = &idToken.Expiry
return ss, nil
}
// createSession takes an oauth2.Token and creates a SessionState from it.
// It alters behavior if called from Redeem vs Refresh
func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, refresh bool) (*sessions.SessionState, error) {
idToken, err := p.verifyIDToken(ctx, token)
if err != nil {
return nil, fmt.Errorf("couldn't extract claims from id_token (%v)", err)
switch err {
case ErrMissingIDToken:
// IDToken is mandatory in Redeem but optional in Refresh
if !refresh {
return nil, errors.New("token response did not contain an id_token")
}
default:
return nil, fmt.Errorf("could not verify id_token: %v", err)
}
}
if token != nil {
newSession.IDToken = token.Extra("id_token").(string)
}
newSession.Email = claims.UserID // TODO Rename SessionState.Email to .UserID in the near future
newSession.User = claims.Subject
newSession.Groups = claims.Groups
newSession.PreferredUsername = claims.PreferredUsername
verifyEmail := (p.UserIDClaim == emailClaim) && !p.AllowUnverifiedEmail
if verifyEmail && claims.Verified != nil && !*claims.Verified {
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.UserID)
}
return newSession, nil
}
// ValidateSessionState checks that the session's IDToken is still valid
func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
_, err := p.Verifier.Verify(ctx, s.IDToken)
return err == nil
}
func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc.IDToken, token *oauth2.Token) (*OIDCClaims, error) {
claims := &OIDCClaims{}
// Extract default claims.
if err := idToken.Claims(&claims); err != nil {
return nil, fmt.Errorf("failed to parse default id_token claims: %v", err)
}
// Extract custom claims.
if err := idToken.Claims(&claims.rawClaims); err != nil {
return nil, fmt.Errorf("failed to parse all id_token claims: %v", err)
}
userID := claims.rawClaims[p.UserIDClaim]
if userID != nil {
claims.UserID = fmt.Sprint(userID)
}
claims.Groups = p.extractGroupsFromRawClaims(claims.rawClaims)
// userID claim was not present or was empty in the ID Token
if claims.UserID == "" {
// BearerToken case, allow empty UserID
// ProfileURL checks below won't work since we don't have an access token
if token == nil {
claims.UserID = claims.Subject
return claims, nil
}
profileURL := p.ProfileURL.String()
if profileURL == "" || token.AccessToken == "" {
return nil, fmt.Errorf("id_token did not contain user ID claim (%q)", p.UserIDClaim)
}
// If the userinfo endpoint profileURL is defined, then there is a chance the userinfo
// contents at the profileURL contains the email.
// Make a query to the userinfo endpoint, and attempt to locate the email from there.
respJSON, err := requests.New(profileURL).
WithContext(ctx).
WithHeaders(makeOIDCHeader(token.AccessToken)).
Do().
UnmarshalJSON()
ss, err := p.buildSessionFromClaims(idToken)
if err != nil {
return nil, err
}
userID, err := respJSON.Get(p.UserIDClaim).String()
if err != nil {
return nil, fmt.Errorf("neither id_token nor userinfo endpoint contained user ID claim (%q)", p.UserIDClaim)
}
ss.AccessToken = token.AccessToken
ss.RefreshToken = token.RefreshToken
ss.IDToken = getIDToken(token)
claims.UserID = userID
}
created := time.Now()
ss.CreatedAt = &created
ss.ExpiresOn = &token.Expiry
return claims, nil
}
func (p *OIDCProvider) extractGroupsFromRawClaims(rawClaims map[string]interface{}) []string {
groups := []string{}
rawGroups, ok := rawClaims[p.GroupsClaim].([]interface{})
if rawGroups != nil && ok {
for _, rawGroup := range rawGroups {
formattedGroup, err := formatGroup(rawGroup)
if err != nil {
logger.Errorf("unable to format group of type %s with error %s", reflect.TypeOf(rawGroup), err)
continue
}
groups = append(groups, formattedGroup)
}
}
return groups
}
func formatGroup(rawGroup interface{}) (string, error) {
group, ok := rawGroup.(string)
if !ok {
jsonGroup, err := json.Marshal(rawGroup)
if err != nil {
return "", err
}
group = string(jsonGroup)
}
return group, nil
}
type OIDCClaims struct {
rawClaims map[string]interface{}
UserID string
Subject string `json:"sub"`
Verified *bool `json:"email_verified"`
PreferredUsername string `json:"preferred_username"`
Groups []string `json:"-"`
return ss, nil
}

View File

@ -2,42 +2,18 @@ package providers
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/coreos/go-oidc"
"github.com/dgrijalva/jwt-go"
"github.com/stretchr/testify/assert"
"golang.org/x/oauth2"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/stretchr/testify/assert"
)
const accessToken = "access_token"
const refreshToken = "refresh_token"
const clientID = "https://test.myapp.com"
const secret = "secret"
type idTokenClaims struct {
Name string `json:"name,omitempty"`
Email string `json:"email,omitempty"`
Phone string `json:"phone_number,omitempty"`
Picture string `json:"picture,omitempty"`
Groups interface{} `json:"groups,omitempty"`
OtherGroups interface{} `json:"other_groups,omitempty"`
jwt.StandardClaims
}
type redeemTokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
@ -46,88 +22,12 @@ type redeemTokenResponse struct {
IDToken string `json:"id_token,omitempty"`
}
var defaultIDToken idTokenClaims = idTokenClaims{
"Jane Dobbs",
"janed@me.com",
"+4798765432",
"http://mugbook.com/janed/me.jpg",
[]string{"test:a", "test:b"},
[]string{"test:c", "test:d"},
jwt.StandardClaims{
Audience: "https://test.myapp.com",
ExpiresAt: time.Now().Add(time.Duration(5) * time.Minute).Unix(),
Id: "id-some-id",
IssuedAt: time.Now().Unix(),
Issuer: "https://issuer.example.com",
NotBefore: 0,
Subject: "123456789",
},
}
var customGroupClaimIDToken idTokenClaims = idTokenClaims{
"Jane Dobbs",
"janed@me.com",
"+4798765432",
"http://mugbook.com/janed/me.jpg",
[]map[string]interface{}{
{
"groupId": "Admin Group Id",
"roles": []string{"Admin"},
},
},
[]string{"test:c", "test:d"},
jwt.StandardClaims{
Audience: "https://test.myapp.com",
ExpiresAt: time.Now().Add(time.Duration(5) * time.Minute).Unix(),
Id: "id-some-id",
IssuedAt: time.Now().Unix(),
Issuer: "https://issuer.example.com",
NotBefore: 0,
Subject: "123456789",
},
}
var minimalIDToken idTokenClaims = idTokenClaims{
"",
"",
"",
"",
[]string{},
[]string{},
jwt.StandardClaims{
Audience: "https://test.myapp.com",
ExpiresAt: time.Now().Add(time.Duration(5) * time.Minute).Unix(),
Id: "id-some-id",
IssuedAt: time.Now().Unix(),
Issuer: "https://issuer.example.com",
NotBefore: 0,
Subject: "minimal",
},
}
type fakeKeySetStub struct{}
func (fakeKeySetStub) VerifySignature(_ context.Context, jwt string) (payload []byte, err error) {
decodeString, err := base64.RawURLEncoding.DecodeString(strings.Split(jwt, ".")[1])
if err != nil {
return nil, err
}
tokenClaims := &idTokenClaims{}
err = json.Unmarshal(decodeString, tokenClaims)
if err != nil || tokenClaims.Id == "this-id-fails-validation" {
return nil, fmt.Errorf("the validation failed for subject [%v]", tokenClaims.Subject)
}
return decodeString, err
}
func newOIDCProvider(serverURL *url.URL) *OIDCProvider {
providerData := &ProviderData{
ProviderName: "oidc",
ClientID: clientID,
ClientSecret: secret,
ClientID: oidcClientID,
ClientSecret: oidcSecret,
LoginURL: &url.URL{
Scheme: serverURL.Scheme,
Host: serverURL.Host,
@ -145,17 +45,16 @@ func newOIDCProvider(serverURL *url.URL) *OIDCProvider {
Host: serverURL.Host,
Path: "/api"},
Scope: "openid profile offline_access",
EmailClaim: "email",
GroupsClaim: "groups",
Verifier: oidc.NewVerifier(
"https://issuer.example.com",
fakeKeySetStub{},
&oidc.Config{ClientID: clientID},
oidcIssuer,
mockJWKS{},
&oidc.Config{ClientID: oidcClientID},
),
}
p := &OIDCProvider{
ProviderData: providerData,
UserIDClaim: "email",
}
p := &OIDCProvider{ProviderData: providerData}
return p
}
@ -169,22 +68,7 @@ func newOIDCServer(body []byte) (*url.URL, *httptest.Server) {
return u, s
}
func newSignedTestIDToken(tokenClaims idTokenClaims) (string, error) {
key, _ := rsa.GenerateKey(rand.Reader, 2048)
standardClaims := jwt.NewWithClaims(jwt.SigningMethodRS256, tokenClaims)
return standardClaims.SignedString(key)
}
func newOauth2Token() *oauth2.Token {
return &oauth2.Token{
AccessToken: accessToken,
TokenType: "Bearer",
RefreshToken: refreshToken,
Expiry: time.Time{}.Add(time.Duration(5) * time.Second),
}
}
func newTestSetup(body []byte) (*httptest.Server, *OIDCProvider) {
func newTestOIDCSetup(body []byte) (*httptest.Server, *OIDCProvider) {
redeemURL, server := newOIDCServer(body)
provider := newOIDCProvider(redeemURL)
return server, provider
@ -201,7 +85,7 @@ func TestOIDCProviderRedeem(t *testing.T) {
IDToken: idToken,
})
server, provider := newTestSetup(body)
server, provider := newTestOIDCSetup(body)
defer server.Close()
session, err := provider.Redeem(context.Background(), provider.RedeemURL.String(), "code1234")
@ -224,8 +108,8 @@ func TestOIDCProviderRedeem_custom_userid(t *testing.T) {
IDToken: idToken,
})
server, provider := newTestSetup(body)
provider.UserIDClaim = "phone_number"
server, provider := newTestOIDCSetup(body)
provider.EmailClaim = "phone_number"
defer server.Close()
session, err := provider.Redeem(context.Background(), provider.RedeemURL.String(), "code1234")
@ -233,6 +117,333 @@ func TestOIDCProviderRedeem_custom_userid(t *testing.T) {
assert.Equal(t, defaultIDToken.Phone, session.Email)
}
func TestOIDCProvider_EnrichSession(t *testing.T) {
testCases := map[string]struct {
ExistingSession *sessions.SessionState
EmailClaim string
GroupsClaim string
ProfileJSON map[string]interface{}
ExpectedError error
ExpectedSession *sessions.SessionState
}{
"Already Populated": {
ExistingSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{"already", "populated"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"email": "new@thing.com",
"groups": []string{"new", "thing"},
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{"already", "populated"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Email": {
ExistingSession: &sessions.SessionState{
User: "missing.email",
Groups: []string{"already", "populated"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"email": "found@email.com",
"groups": []string{"new", "thing"},
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "missing.email",
Email: "found@email.com",
Groups: []string{"already", "populated"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Email Only in Profile URL": {
ExistingSession: &sessions.SessionState{
User: "missing.email",
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"email": "found@email.com",
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "missing.email",
Email: "found@email.com",
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Email with Custom Claim": {
ExistingSession: &sessions.SessionState{
User: "missing.email",
Groups: []string{"already", "populated"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "weird",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"weird": "weird@claim.com",
"groups": []string{"new", "thing"},
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "missing.email",
Email: "weird@claim.com",
Groups: []string{"already", "populated"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Email not in Profile URL": {
ExistingSession: &sessions.SessionState{
User: "missing.email",
Groups: []string{"already", "populated"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"groups": []string{"new", "thing"},
},
ExpectedError: errors.New("neither the id_token nor the profileURL set an email"),
ExpectedSession: &sessions.SessionState{
User: "missing.email",
Groups: []string{"already", "populated"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Groups": {
ExistingSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: nil,
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"email": "new@thing.com",
"groups": []string{"new", "thing"},
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{"new", "thing"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Groups with Complex Groups in Profile URL": {
ExistingSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: nil,
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"email": "new@thing.com",
"groups": []map[string]interface{}{
{
"groupId": "Admin Group Id",
"roles": []string{"Admin"},
},
},
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Groups with Singleton Complex Group in Profile URL": {
ExistingSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: nil,
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"email": "new@thing.com",
"groups": map[string]interface{}{
"groupId": "Admin Group Id",
"roles": []string{"Admin"},
},
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Empty Groups Claims": {
ExistingSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"email": "new@thing.com",
"groups": []string{"new", "thing"},
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Groups with Custom Claim": {
ExistingSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: nil,
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "roles",
ProfileJSON: map[string]interface{}{
"email": "new@thing.com",
"roles": []string{"new", "thing", "roles"},
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{"new", "thing", "roles"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Groups String Profile URL Response": {
ExistingSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: nil,
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"email": "new@thing.com",
"groups": "singleton",
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
Groups: []string{"singleton"},
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
"Missing Groups in both Claims and Profile URL": {
ExistingSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
EmailClaim: "email",
GroupsClaim: "groups",
ProfileJSON: map[string]interface{}{
"email": "new@thing.com",
},
ExpectedError: nil,
ExpectedSession: &sessions.SessionState{
User: "already",
Email: "already@populated.com",
IDToken: idToken,
AccessToken: accessToken,
RefreshToken: refreshToken,
},
},
}
for testName, tc := range testCases {
t.Run(testName, func(t *testing.T) {
jsonResp, err := json.Marshal(tc.ProfileJSON)
assert.NoError(t, err)
server, provider := newTestOIDCSetup(jsonResp)
provider.ProfileURL, err = url.Parse(server.URL)
assert.NoError(t, err)
provider.EmailClaim = tc.EmailClaim
provider.GroupsClaim = tc.GroupsClaim
defer server.Close()
err = provider.EnrichSession(context.Background(), tc.ExistingSession)
assert.Equal(t, tc.ExpectedError, err)
assert.Equal(t, *tc.ExpectedSession, *tc.ExistingSession)
})
}
}
func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) {
idToken, _ := newSignedTestIDToken(defaultIDToken)
@ -243,7 +454,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) {
RefreshToken: refreshToken,
})
server, provider := newTestSetup(body)
server, provider := newTestOIDCSetup(body)
defer server.Close()
existingSession := &sessions.SessionState{
@ -277,7 +488,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) {
IDToken: idToken,
})
server, provider := newTestSetup(body)
server, provider := newTestOIDCSetup(body)
defer server.Close()
existingSession := &sessions.SessionState{
@ -300,48 +511,45 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) {
}
func TestOIDCProviderCreateSessionFromToken(t *testing.T) {
const profileURLEmail = "janed@me.com"
testCases := map[string]struct {
IDToken idTokenClaims
GroupsClaim string
ExpectedUser string
ExpectedEmail string
ExpectedGroups interface{}
ExpectedGroups []string
}{
"Default IDToken": {
IDToken: defaultIDToken,
GroupsClaim: "groups",
ExpectedUser: defaultIDToken.Subject,
ExpectedEmail: defaultIDToken.Email,
ExpectedUser: "123456789",
ExpectedEmail: "janed@me.com",
ExpectedGroups: []string{"test:a", "test:b"},
},
"Minimal IDToken with no email claim": {
IDToken: minimalIDToken,
GroupsClaim: "groups",
ExpectedUser: minimalIDToken.Subject,
ExpectedEmail: minimalIDToken.Subject,
ExpectedGroups: []string{},
ExpectedUser: "123456789",
ExpectedEmail: "123456789",
ExpectedGroups: nil,
},
"Custom Groups Claim": {
IDToken: defaultIDToken,
GroupsClaim: "other_groups",
ExpectedUser: defaultIDToken.Subject,
ExpectedEmail: defaultIDToken.Email,
GroupsClaim: "roles",
ExpectedUser: "123456789",
ExpectedEmail: "janed@me.com",
ExpectedGroups: []string{"test:c", "test:d"},
},
"Custom Groups Claim2": {
IDToken: customGroupClaimIDToken,
"Complex Groups Claim": {
IDToken: complexGroupsIDToken,
GroupsClaim: "groups",
ExpectedUser: customGroupClaimIDToken.Subject,
ExpectedEmail: customGroupClaimIDToken.Email,
ExpectedUser: "123456789",
ExpectedEmail: "complex@claims.com",
ExpectedGroups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"},
},
}
for testName, tc := range testCases {
t.Run(testName, func(t *testing.T) {
jsonResp := []byte(fmt.Sprintf(`{"email":"%s"}`, profileURLEmail))
server, provider := newTestSetup(jsonResp)
server, provider := newTestOIDCSetup([]byte(`{}`))
provider.GroupsClaim = tc.GroupsClaim
defer server.Close()
@ -353,75 +561,10 @@ func TestOIDCProviderCreateSessionFromToken(t *testing.T) {
assert.Equal(t, tc.ExpectedUser, ss.User)
assert.Equal(t, tc.ExpectedEmail, ss.Email)
assert.Equal(t, tc.ExpectedGroups, ss.Groups)
assert.Equal(t, rawIDToken, ss.IDToken)
assert.Equal(t, rawIDToken, ss.AccessToken)
assert.Equal(t, tc.ExpectedGroups, ss.Groups)
assert.Equal(t, "", ss.RefreshToken)
})
}
}
func TestOIDCProvider_findVerifiedIdToken(t *testing.T) {
server, provider := newTestSetup([]byte(""))
defer server.Close()
token := newOauth2Token()
signedIDToken, _ := newSignedTestIDToken(defaultIDToken)
tokenWithIDToken := token.WithExtra(map[string]interface{}{
"id_token": signedIDToken,
})
verifiedIDToken, err := provider.findVerifiedIDToken(context.Background(), tokenWithIDToken)
assert.Equal(t, true, err == nil)
if verifiedIDToken == nil {
t.Fatal("verifiedIDToken is nil")
}
assert.Equal(t, defaultIDToken.Issuer, verifiedIDToken.Issuer)
assert.Equal(t, defaultIDToken.Subject, verifiedIDToken.Subject)
// When the validation fails the response should be nil
defaultIDToken.Id = "this-id-fails-validation"
signedIDToken, _ = newSignedTestIDToken(defaultIDToken)
tokenWithIDToken = token.WithExtra(map[string]interface{}{
"id_token": signedIDToken,
})
verifiedIDToken, err = provider.findVerifiedIDToken(context.Background(), tokenWithIDToken)
assert.Equal(t, errors.New("failed to verify signature: the validation failed for subject [123456789]"), err)
assert.Equal(t, true, verifiedIDToken == nil)
// When there is no id token in the oauth token
verifiedIDToken, err = provider.findVerifiedIDToken(context.Background(), newOauth2Token())
assert.Equal(t, nil, err)
assert.Equal(t, true, verifiedIDToken == nil)
}
func Test_formatGroup(t *testing.T) {
testCases := map[string]struct {
RawGroup interface{}
ExpectedFormattedGroupValue string
}{
"String Group": {
RawGroup: "group",
ExpectedFormattedGroupValue: "group",
},
"Map Group": {
RawGroup: map[string]string{"id": "1", "name": "Test"},
ExpectedFormattedGroupValue: "{\"id\":\"1\",\"name\":\"Test\"}",
},
"List Group": {
RawGroup: []string{"First", "Second"},
ExpectedFormattedGroupValue: "[\"First\",\"Second\"]",
},
}
for testName, tc := range testCases {
t.Run(testName, func(t *testing.T) {
formattedGroup, err := formatGroup(tc.RawGroup)
assert.Nil(t, err)
assert.Equal(t, tc.ExpectedFormattedGroupValue, formattedGroup)
})
}
}

View File

@ -1,12 +1,23 @@
package providers
import (
"context"
"errors"
"fmt"
"io/ioutil"
"net/url"
"reflect"
"strings"
"github.com/coreos/go-oidc"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"golang.org/x/oauth2"
)
const (
OIDCEmailClaim = "email"
OIDCGroupsClaim = "groups"
)
// ProviderData contains information required to configure all implementations
@ -27,6 +38,11 @@ type ProviderData struct {
ClientSecretFile string
Scope string
Prompt string
// Common OIDC options for any OIDC-based providers to consume
AllowUnverifiedEmail bool
EmailClaim string
GroupsClaim string
Verifier *oidc.IDTokenVerifier
// Universal Group authorization data structure
@ -94,3 +110,116 @@ func defaultURL(u *url.URL, d *url.URL) *url.URL {
}
return &url.URL{}
}
// ****************************************************************************
// These private OIDC helper methods are available to any providers that are
// OIDC compliant
// ****************************************************************************
// OIDCClaims is a struct to unmarshal the OIDC claims from an ID Token payload
type OIDCClaims struct {
Subject string `json:"sub"`
Email string `json:"-"`
Groups []string `json:"-"`
Verified *bool `json:"email_verified"`
raw map[string]interface{}
}
func (p *ProviderData) verifyIDToken(ctx context.Context, token *oauth2.Token) (*oidc.IDToken, error) {
rawIDToken := getIDToken(token)
if strings.TrimSpace(rawIDToken) == "" {
return nil, ErrMissingIDToken
}
if p.Verifier == nil {
return nil, ErrMissingOIDCVerifier
}
return p.Verifier.Verify(ctx, rawIDToken)
}
// buildSessionFromClaims uses IDToken claims to populate a fresh SessionState
// with non-Token related fields.
func (p *ProviderData) buildSessionFromClaims(idToken *oidc.IDToken) (*sessions.SessionState, error) {
ss := &sessions.SessionState{}
if idToken == nil {
return ss, nil
}
claims, err := p.getClaims(idToken)
if err != nil {
return nil, fmt.Errorf("couldn't extract claims from id_token (%v)", err)
}
ss.User = claims.Subject
ss.Email = claims.Email
ss.Groups = claims.Groups
// TODO (@NickMeves) Deprecate for dynamic claim to session mapping
if pref, ok := claims.raw["preferred_username"].(string); ok {
ss.PreferredUsername = pref
}
// `email_verified` must be present and explicitly set to `false` to be
// considered unverified.
verifyEmail := (p.EmailClaim == OIDCEmailClaim) && !p.AllowUnverifiedEmail
if verifyEmail && claims.Verified != nil && !*claims.Verified {
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email)
}
return ss, nil
}
// getClaims extracts IDToken claims into an OIDCClaims
func (p *ProviderData) getClaims(idToken *oidc.IDToken) (*OIDCClaims, error) {
claims := &OIDCClaims{}
// Extract default claims.
if err := idToken.Claims(&claims); err != nil {
return nil, fmt.Errorf("failed to parse default id_token claims: %v", err)
}
// Extract custom claims.
if err := idToken.Claims(&claims.raw); err != nil {
return nil, fmt.Errorf("failed to parse all id_token claims: %v", err)
}
email := claims.raw[p.EmailClaim]
if email != nil {
claims.Email = fmt.Sprint(email)
}
claims.Groups = p.extractGroups(claims.raw)
return claims, nil
}
// extractGroups extracts groups from a claim to a list in a type safe manner.
// If the claim isn't present, `nil` is returned. If the groups claim is
// present but empty, `[]string{}` is returned.
func (p *ProviderData) extractGroups(claims map[string]interface{}) []string {
rawClaim, ok := claims[p.GroupsClaim]
if !ok {
return nil
}
// Handle traditional list-based groups as well as non-standard singleton
// based groups. Both variants support complex objects if needed.
var claimGroups []interface{}
switch raw := rawClaim.(type) {
case []interface{}:
claimGroups = raw
case interface{}:
claimGroups = []interface{}{raw}
}
groups := []string{}
for _, rawGroup := range claimGroups {
formattedGroup, err := formatGroup(rawGroup)
if err != nil {
logger.Errorf("Warning: unable to format group of type %s with error %s",
reflect.TypeOf(rawGroup), err)
continue
}
groups = append(groups, formattedGroup)
}
return groups
}

View File

@ -0,0 +1,437 @@
package providers
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"strings"
"testing"
"time"
"github.com/coreos/go-oidc"
"github.com/dgrijalva/jwt-go"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
. "github.com/onsi/gomega"
"golang.org/x/oauth2"
)
const (
idToken = "eyJfoobar123.eyJbaz987.IDToken"
accessToken = "eyJfoobar123.eyJbaz987.AccessToken"
refreshToken = "eyJfoobar123.eyJbaz987.RefreshToken"
oidcIssuer = "https://issuer.example.com"
oidcClientID = "https://test.myapp.com"
oidcSecret = "SuperSecret123456789"
failureTokenID = "this-id-fails-verification"
)
var (
verified = true
unverified = false
standardClaims = jwt.StandardClaims{
Audience: oidcClientID,
ExpiresAt: time.Now().Add(time.Duration(5) * time.Minute).Unix(),
Id: "id-some-id",
IssuedAt: time.Now().Unix(),
Issuer: oidcIssuer,
NotBefore: 0,
Subject: "123456789",
}
defaultIDToken = idTokenClaims{
Name: "Jane Dobbs",
Email: "janed@me.com",
Phone: "+4798765432",
Picture: "http://mugbook.com/janed/me.jpg",
Groups: []string{"test:a", "test:b"},
Roles: []string{"test:c", "test:d"},
Verified: &verified,
StandardClaims: standardClaims,
}
complexGroupsIDToken = idTokenClaims{
Name: "Complex Claim",
Email: "complex@claims.com",
Phone: "+5439871234",
Picture: "http://mugbook.com/complex/claims.jpg",
Groups: []map[string]interface{}{
{
"groupId": "Admin Group Id",
"roles": []string{"Admin"},
},
},
Roles: []string{"test:simple", "test:roles"},
Verified: &verified,
StandardClaims: standardClaims,
}
unverifiedIDToken = idTokenClaims{
Name: "Mystery Man",
Email: "unverified@email.com",
Phone: "+4025205729",
Picture: "http://mugbook.com/unverified/email.jpg",
Groups: []string{"test:a", "test:b"},
Roles: []string{"test:c", "test:d"},
Verified: &unverified,
StandardClaims: standardClaims,
}
minimalIDToken = idTokenClaims{
StandardClaims: standardClaims,
}
)
type idTokenClaims struct {
Name string `json:"preferred_username,omitempty"`
Email string `json:"email,omitempty"`
Phone string `json:"phone_number,omitempty"`
Picture string `json:"picture,omitempty"`
Groups interface{} `json:"groups,omitempty"`
Roles interface{} `json:"roles,omitempty"`
Verified *bool `json:"email_verified,omitempty"`
jwt.StandardClaims
}
type mockJWKS struct{}
func (mockJWKS) VerifySignature(_ context.Context, jwt string) ([]byte, error) {
decoded, err := base64.RawURLEncoding.DecodeString(strings.Split(jwt, ".")[1])
if err != nil {
return nil, err
}
tokenClaims := &idTokenClaims{}
err = json.Unmarshal(decoded, tokenClaims)
if err != nil || tokenClaims.Id == failureTokenID {
return nil, fmt.Errorf("the validation failed for subject [%v]", tokenClaims.Subject)
}
return decoded, nil
}
func newSignedTestIDToken(tokenClaims idTokenClaims) (string, error) {
key, _ := rsa.GenerateKey(rand.Reader, 2048)
standardClaims := jwt.NewWithClaims(jwt.SigningMethodRS256, tokenClaims)
return standardClaims.SignedString(key)
}
func newTestOauth2Token() *oauth2.Token {
return &oauth2.Token{
AccessToken: accessToken,
TokenType: "Bearer",
RefreshToken: refreshToken,
Expiry: time.Time{}.Add(time.Duration(5) * time.Second),
}
}
func TestProviderData_verifyIDToken(t *testing.T) {
failureIDToken := defaultIDToken
failureIDToken.Id = failureTokenID
testCases := map[string]struct {
IDToken *idTokenClaims
Verifier bool
ExpectIDToken bool
ExpectedError error
}{
"Valid ID Token": {
IDToken: &defaultIDToken,
Verifier: true,
ExpectIDToken: true,
ExpectedError: nil,
},
"Invalid ID Token": {
IDToken: &failureIDToken,
Verifier: true,
ExpectIDToken: false,
ExpectedError: errors.New("failed to verify signature: the validation failed for subject [123456789]"),
},
"Missing ID Token": {
IDToken: nil,
Verifier: true,
ExpectIDToken: false,
ExpectedError: ErrMissingIDToken,
},
"OIDC Verifier not Configured": {
IDToken: &defaultIDToken,
Verifier: false,
ExpectIDToken: false,
ExpectedError: ErrMissingOIDCVerifier,
},
}
for testName, tc := range testCases {
t.Run(testName, func(t *testing.T) {
g := NewWithT(t)
token := newTestOauth2Token()
if tc.IDToken != nil {
idToken, err := newSignedTestIDToken(*tc.IDToken)
g.Expect(err).ToNot(HaveOccurred())
token = token.WithExtra(map[string]interface{}{
"id_token": idToken,
})
}
provider := &ProviderData{}
if tc.Verifier {
provider.Verifier = oidc.NewVerifier(
oidcIssuer,
mockJWKS{},
&oidc.Config{ClientID: oidcClientID},
)
}
verified, err := provider.verifyIDToken(context.Background(), token)
if err != nil {
g.Expect(err).To(Equal(tc.ExpectedError))
}
if tc.ExpectIDToken {
g.Expect(verified).ToNot(BeNil())
g.Expect(*verified).To(BeAssignableToTypeOf(oidc.IDToken{}))
} else {
g.Expect(verified).To(BeNil())
}
})
}
}
func TestProviderData_buildSessionFromClaims(t *testing.T) {
testCases := map[string]struct {
IDToken idTokenClaims
AllowUnverified bool
EmailClaim string
GroupsClaim string
ExpectedError error
ExpectedSession *sessions.SessionState
}{
"Standard": {
IDToken: defaultIDToken,
AllowUnverified: false,
EmailClaim: "email",
GroupsClaim: "groups",
ExpectedSession: &sessions.SessionState{
User: "123456789",
Email: "janed@me.com",
Groups: []string{"test:a", "test:b"},
PreferredUsername: "Jane Dobbs",
},
},
"Unverified Denied": {
IDToken: unverifiedIDToken,
AllowUnverified: false,
EmailClaim: "email",
GroupsClaim: "groups",
ExpectedError: errors.New("email in id_token (unverified@email.com) isn't verified"),
},
"Unverified Allowed": {
IDToken: unverifiedIDToken,
AllowUnverified: true,
EmailClaim: "email",
GroupsClaim: "groups",
ExpectedSession: &sessions.SessionState{
User: "123456789",
Email: "unverified@email.com",
Groups: []string{"test:a", "test:b"},
PreferredUsername: "Mystery Man",
},
},
"Complex Groups": {
IDToken: complexGroupsIDToken,
AllowUnverified: true,
EmailClaim: "email",
GroupsClaim: "groups",
ExpectedSession: &sessions.SessionState{
User: "123456789",
Email: "complex@claims.com",
Groups: []string{"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}"},
PreferredUsername: "Complex Claim",
},
},
"Email Claim Switched": {
IDToken: unverifiedIDToken,
AllowUnverified: true,
EmailClaim: "phone_number",
GroupsClaim: "groups",
ExpectedSession: &sessions.SessionState{
User: "123456789",
Email: "+4025205729",
Groups: []string{"test:a", "test:b"},
PreferredUsername: "Mystery Man",
},
},
"Email Claim Switched to Non String": {
IDToken: unverifiedIDToken,
AllowUnverified: true,
EmailClaim: "roles",
GroupsClaim: "groups",
ExpectedSession: &sessions.SessionState{
User: "123456789",
Email: "[test:c test:d]",
Groups: []string{"test:a", "test:b"},
PreferredUsername: "Mystery Man",
},
},
"Email Claim Non Existent": {
IDToken: unverifiedIDToken,
AllowUnverified: true,
EmailClaim: "aksjdfhjksadh",
GroupsClaim: "groups",
ExpectedSession: &sessions.SessionState{
User: "123456789",
Email: "",
Groups: []string{"test:a", "test:b"},
PreferredUsername: "Mystery Man",
},
},
"Groups Claim Switched": {
IDToken: defaultIDToken,
AllowUnverified: false,
EmailClaim: "email",
GroupsClaim: "roles",
ExpectedSession: &sessions.SessionState{
User: "123456789",
Email: "janed@me.com",
Groups: []string{"test:c", "test:d"},
PreferredUsername: "Jane Dobbs",
},
},
"Groups Claim Non Existent": {
IDToken: defaultIDToken,
AllowUnverified: false,
EmailClaim: "email",
GroupsClaim: "alskdjfsalkdjf",
ExpectedSession: &sessions.SessionState{
User: "123456789",
Email: "janed@me.com",
Groups: nil,
PreferredUsername: "Jane Dobbs",
},
},
}
for testName, tc := range testCases {
t.Run(testName, func(t *testing.T) {
g := NewWithT(t)
provider := &ProviderData{
Verifier: oidc.NewVerifier(
oidcIssuer,
mockJWKS{},
&oidc.Config{ClientID: oidcClientID},
),
}
provider.AllowUnverifiedEmail = tc.AllowUnverified
provider.EmailClaim = tc.EmailClaim
provider.GroupsClaim = tc.GroupsClaim
rawIDToken, err := newSignedTestIDToken(tc.IDToken)
g.Expect(err).ToNot(HaveOccurred())
idToken, err := provider.Verifier.Verify(context.Background(), rawIDToken)
g.Expect(err).ToNot(HaveOccurred())
ss, err := provider.buildSessionFromClaims(idToken)
if err != nil {
g.Expect(err).To(Equal(tc.ExpectedError))
}
if ss != nil {
g.Expect(ss).To(Equal(tc.ExpectedSession))
}
})
}
}
func TestProviderData_extractGroups(t *testing.T) {
testCases := map[string]struct {
Claims map[string]interface{}
GroupsClaim string
ExpectedGroups []string
}{
"Standard String Groups": {
Claims: map[string]interface{}{
"email": "this@does.not.matter.com",
"groups": []interface{}{"three", "string", "groups"},
},
GroupsClaim: "groups",
ExpectedGroups: []string{"three", "string", "groups"},
},
"Different Claim Name": {
Claims: map[string]interface{}{
"email": "this@does.not.matter.com",
"roles": []interface{}{"three", "string", "roles"},
},
GroupsClaim: "roles",
ExpectedGroups: []string{"three", "string", "roles"},
},
"Numeric Groups": {
Claims: map[string]interface{}{
"email": "this@does.not.matter.com",
"groups": []interface{}{1, 2, 3},
},
GroupsClaim: "groups",
ExpectedGroups: []string{"1", "2", "3"},
},
"Complex Groups": {
Claims: map[string]interface{}{
"email": "this@does.not.matter.com",
"groups": []interface{}{
map[string]interface{}{
"groupId": "Admin Group Id",
"roles": []string{"Admin"},
},
12345,
"Just::A::String",
},
},
GroupsClaim: "groups",
ExpectedGroups: []string{
"{\"groupId\":\"Admin Group Id\",\"roles\":[\"Admin\"]}",
"12345",
"Just::A::String",
},
},
"Missing Groups Claim Returns Nil": {
Claims: map[string]interface{}{
"email": "this@does.not.matter.com",
},
GroupsClaim: "groups",
ExpectedGroups: nil,
},
"Non List Groups": {
Claims: map[string]interface{}{
"email": "this@does.not.matter.com",
"groups": "singleton",
},
GroupsClaim: "groups",
ExpectedGroups: []string{"singleton"},
},
}
for testName, tc := range testCases {
t.Run(testName, func(t *testing.T) {
g := NewWithT(t)
provider := &ProviderData{
Verifier: oidc.NewVerifier(
oidcIssuer,
mockJWKS{},
&oidc.Config{ClientID: oidcClientID},
),
}
provider.GroupsClaim = tc.GroupsClaim
groups := provider.extractGroups(tc.Claims)
if tc.ExpectedGroups != nil {
g.Expect(groups).To(Equal(tc.ExpectedGroups))
} else {
g.Expect(groups).To(BeNil())
}
})
}
}

View File

@ -22,6 +22,14 @@ var (
// code
ErrMissingCode = errors.New("missing code")
// ErrMissingIDToken is returned when an oidc.Token does not contain the
// extra `id_token` field for an IDToken.
ErrMissingIDToken = errors.New("missing id_token")
// ErrMissingOIDCVerifier is returned when a provider didn't set `Verifier`
// but an attempt to call `Verifier.Verify` was about to be made.
ErrMissingOIDCVerifier = errors.New("oidc verifier is not configured")
_ Provider = (*ProviderData)(nil)
)

View File

@ -1,9 +1,13 @@
package providers
import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"github.com/bitly/go-simplejson"
"golang.org/x/oauth2"
)
const (
@ -55,3 +59,42 @@ func makeLoginURL(p *ProviderData, redirectURI, state string, extraParams url.Va
a.RawQuery = params.Encode()
return a
}
// getIDToken extracts an IDToken stored in the `Extra` fields of an
// oauth2.Token
func getIDToken(token *oauth2.Token) string {
idToken, ok := token.Extra("id_token").(string)
if !ok {
return ""
}
return idToken
}
// formatGroup coerces an OIDC groups claim into a string
// If it is non-string, marshal it into JSON.
func formatGroup(rawGroup interface{}) (string, error) {
if group, ok := rawGroup.(string); ok {
return group, nil
}
jsonGroup, err := json.Marshal(rawGroup)
if err != nil {
return "", err
}
return string(jsonGroup), nil
}
// coerceArray extracts a field from simplejson.Json that might be a
// singleton or a list and coerces it into a list.
func coerceArray(sj *simplejson.Json, key string) []interface{} {
array, err := sj.Get(key).Array()
if err == nil {
return array
}
single := sj.Get(key).Interface()
if single == nil {
return nil
}
return []interface{}{single}
}

View File

@ -5,9 +5,10 @@ import (
"testing"
. "github.com/onsi/gomega"
"golang.org/x/oauth2"
)
func TestMakeAuhtorizationHeader(t *testing.T) {
func Test_makeAuthorizationHeader(t *testing.T) {
testCases := []struct {
name string
prefix string
@ -64,3 +65,49 @@ func TestMakeAuhtorizationHeader(t *testing.T) {
})
}
}
func Test_getIDToken(t *testing.T) {
const idToken = "eyJfoobar.eyJfoobar.12345asdf"
g := NewWithT(t)
token := &oauth2.Token{}
g.Expect(getIDToken(token)).To(Equal(""))
extraToken := token.WithExtra(map[string]interface{}{
"id_token": idToken,
})
g.Expect(getIDToken(extraToken)).To(Equal(idToken))
}
func Test_formatGroup(t *testing.T) {
testCases := map[string]struct {
rawGroup interface{}
expected string
}{
"String Group": {
rawGroup: "group",
expected: "group",
},
"Numeric Group": {
rawGroup: 123,
expected: "123",
},
"Map Group": {
rawGroup: map[string]string{"id": "1", "name": "Test"},
expected: "{\"id\":\"1\",\"name\":\"Test\"}",
},
"List Group": {
rawGroup: []string{"First", "Second"},
expected: "[\"First\",\"Second\"]",
},
}
for testName, tc := range testCases {
t.Run(testName, func(t *testing.T) {
g := NewWithT(t)
formattedGroup, err := formatGroup(tc.rawGroup)
g.Expect(err).ToNot(HaveOccurred())
g.Expect(formattedGroup).To(Equal(tc.expected))
})
}
}