171 lines
		
	
	
		
			4.5 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			171 lines
		
	
	
		
			4.5 KiB
		
	
	
	
		
			Go
		
	
	
	
| package providers
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"encoding/base64"
 | |
| 	"encoding/json"
 | |
| 	"fmt"
 | |
| 	"strings"
 | |
| 
 | |
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
 | |
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
 | |
| )
 | |
| 
 | |
// keycloakOIDCProviderName is the name applied as the provider default when
// a KeycloakOIDCProvider is constructed.
const keycloakOIDCProviderName = "Keycloak OIDC"
 | |
| 
 | |
| // KeycloakOIDCProvider creates a Keycloak provider based on OIDCProvider
 | |
| type KeycloakOIDCProvider struct {
 | |
| 	*OIDCProvider
 | |
| }
 | |
| 
 | |
| // NewKeycloakOIDCProvider makes a KeycloakOIDCProvider using the ProviderData
 | |
| func NewKeycloakOIDCProvider(p *ProviderData, opts options.Provider) *KeycloakOIDCProvider {
 | |
| 	p.setProviderDefaults(providerDefaults{
 | |
| 		name: keycloakOIDCProviderName,
 | |
| 	})
 | |
| 
 | |
| 	provider := &KeycloakOIDCProvider{
 | |
| 		OIDCProvider: NewOIDCProvider(p, opts.OIDCConfig),
 | |
| 	}
 | |
| 
 | |
| 	provider.addAllowedRoles(opts.KeycloakConfig.Roles)
 | |
| 	return provider
 | |
| }
 | |
| 
 | |
| var _ Provider = (*KeycloakOIDCProvider)(nil)
 | |
| 
 | |
| // addAllowedRoles sets Keycloak roles that are authorized.
 | |
| // Assumes `SetAllowedGroups` is already called on groups and appends to that
 | |
| // with `role:` prefixed roles.
 | |
| func (p *KeycloakOIDCProvider) addAllowedRoles(roles []string) {
 | |
| 	if p.AllowedGroups == nil {
 | |
| 		p.AllowedGroups = make(map[string]struct{})
 | |
| 	}
 | |
| 	for _, role := range roles {
 | |
| 		p.AllowedGroups[formatRole(role)] = struct{}{}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // CreateSessionFromToken converts Bearer IDTokens into sessions
 | |
| func (p *KeycloakOIDCProvider) CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) {
 | |
| 	ss, err := p.OIDCProvider.CreateSessionFromToken(ctx, token)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("could not create session from token: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	// Extract custom keycloak roles and enrich session
 | |
| 	if err := p.extractRoles(ss); err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return ss, nil
 | |
| }
 | |
| 
 | |
| // EnrichSession is called after Redeem to allow providers to enrich session fields
 | |
| // such as User, Email, Groups with provider specific API calls.
 | |
| func (p *KeycloakOIDCProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error {
 | |
| 	err := p.OIDCProvider.EnrichSession(ctx, s)
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("could not enrich oidc session: %v", err)
 | |
| 	}
 | |
| 	return p.extractRoles(s)
 | |
| }
 | |
| 
 | |
| // RefreshSession adds role extraction logic to the refresh flow
 | |
| func (p *KeycloakOIDCProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
 | |
| 	refreshed, err := p.OIDCProvider.RefreshSession(ctx, s)
 | |
| 
 | |
| 	// Refresh could have failed or there was not session to refresh (with no error raised)
 | |
| 	if err != nil || !refreshed {
 | |
| 		return refreshed, err
 | |
| 	}
 | |
| 
 | |
| 	return true, p.extractRoles(s)
 | |
| }
 | |
| 
 | |
| func (p *KeycloakOIDCProvider) extractRoles(s *sessions.SessionState) error {
 | |
| 	claims, err := p.getAccessClaims(s)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	var roles []string
 | |
| 	roles = append(roles, claims.RealmAccess.Roles...)
 | |
| 	roles = append(roles, getClientRoles(claims)...)
 | |
| 
 | |
| 	// Add to groups list with `role:` prefix to distinguish from groups
 | |
| 	for _, role := range roles {
 | |
| 		s.Groups = append(s.Groups, formatRole(role))
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
type (
	// realmAccess mirrors the `realm_access` claim: the list of roles
	// carried under its `roles` key.
	realmAccess struct {
		Roles []string `json:"roles"`
	}

	// accessClaims holds the subset of access token claims used for role
	// extraction: the `realm_access` claim and the raw `resource_access`
	// claim, which is keyed by client name and decoded lazily by
	// getClientRoles.
	accessClaims struct {
		RealmAccess    realmAccess            `json:"realm_access"`
		ResourceAccess map[string]interface{} `json:"resource_access"`
	}
)
 | |
| 
 | |
| func (p *KeycloakOIDCProvider) getAccessClaims(s *sessions.SessionState) (*accessClaims, error) {
 | |
| 	parts := strings.Split(s.AccessToken, ".")
 | |
| 	if len(parts) < 2 {
 | |
| 		return nil, fmt.Errorf("malformed access token, expected 3 parts got %d", len(parts))
 | |
| 	}
 | |
| 
 | |
| 	payload, err := base64.RawURLEncoding.DecodeString(parts[1])
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("malformed access token, couldn't extract jwt payload: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	var claims accessClaims
 | |
| 	if err := json.Unmarshal(payload, &claims); err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	return &claims, nil
 | |
| }
 | |
| 
 | |
| // getClientRoles extracts client roles from the `resource_access` claim with
 | |
| // the format `client:role`.
 | |
| //
 | |
| // ResourceAccess format:
 | |
| //
 | |
| //	"resource_access": {
 | |
| //	  "clientA": {
 | |
| //	    "roles": [
 | |
| //	      "roleA"
 | |
| //	    ]
 | |
| //	  },
 | |
| //	  "clientB": {
 | |
| //	    "roles": [
 | |
| //	      "roleA",
 | |
| //	      "roleB",
 | |
| //	      "roleC"
 | |
| //	    ]
 | |
| //	  }
 | |
| //	}
 | |
| func getClientRoles(claims *accessClaims) []string {
 | |
| 	var clientRoles []string
 | |
| 	for clientName, access := range claims.ResourceAccess {
 | |
| 		accessMap, ok := access.(map[string]interface{})
 | |
| 		if !ok {
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		var roles interface{}
 | |
| 		if roles, ok = accessMap["roles"]; !ok {
 | |
| 			continue
 | |
| 		}
 | |
| 		for _, role := range roles.([]interface{}) {
 | |
| 			clientRoles = append(clientRoles, fmt.Sprintf("%s:%s", clientName, role))
 | |
| 		}
 | |
| 	}
 | |
| 	return clientRoles
 | |
| }
 | |
| 
 | |
// formatRole returns the given role with the `role:` prefix that marks an
// entry as a role rather than a group.
func formatRole(role string) string {
	const prefix = "role:"
	// Both operands are strings, so fmt.Sprint joins them without a separator.
	return fmt.Sprint(prefix, role)
}
 |