113 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			113 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			Go
		
	
	
	
| package header
 | |
| 
 | |
| import (
 | |
| 	"encoding/base64"
 | |
| 	"fmt"
 | |
| 	"net/http"
 | |
| 
 | |
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
 | |
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options/util"
 | |
| 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
 | |
| )
 | |
| 
 | |
| type Injector interface {
 | |
| 	Inject(http.Header, *sessionsapi.SessionState)
 | |
| }
 | |
| 
 | |
| type injector struct {
 | |
| 	valueInjectors []valueInjector
 | |
| }
 | |
| 
 | |
| func (i injector) Inject(header http.Header, session *sessionsapi.SessionState) {
 | |
| 	for _, injector := range i.valueInjectors {
 | |
| 		injector.inject(header, session)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func NewInjector(headers []options.Header) (Injector, error) {
 | |
| 	injectors := []valueInjector{}
 | |
| 	for _, header := range headers {
 | |
| 		for _, value := range header.Values {
 | |
| 			injector, err := newValueinjector(header.Name, value)
 | |
| 			if err != nil {
 | |
| 				return nil, fmt.Errorf("error building injector for header %q: %v", header.Name, err)
 | |
| 			}
 | |
| 			injectors = append(injectors, injector)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return &injector{valueInjectors: injectors}, nil
 | |
| }
 | |
| 
 | |
| type valueInjector interface {
 | |
| 	inject(http.Header, *sessionsapi.SessionState)
 | |
| }
 | |
| 
 | |
| func newValueinjector(name string, value options.HeaderValue) (valueInjector, error) {
 | |
| 	switch {
 | |
| 	case value.SecretSource != nil && value.ClaimSource == nil:
 | |
| 		return newSecretInjector(name, value.SecretSource)
 | |
| 	case value.SecretSource == nil && value.ClaimSource != nil:
 | |
| 		return newClaimInjector(name, value.ClaimSource)
 | |
| 	default:
 | |
| 		return nil, fmt.Errorf("header %q value has multiple entries: only one entry per value is allowed", name)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| type injectorFunc struct {
 | |
| 	injectFunc func(http.Header, *sessionsapi.SessionState)
 | |
| }
 | |
| 
 | |
| func (i *injectorFunc) inject(header http.Header, session *sessionsapi.SessionState) {
 | |
| 	i.injectFunc(header, session)
 | |
| }
 | |
| 
 | |
| func newInjectorFunc(injectFunc func(header http.Header, session *sessionsapi.SessionState)) valueInjector {
 | |
| 	return &injectorFunc{injectFunc: injectFunc}
 | |
| }
 | |
| 
 | |
| func newSecretInjector(name string, source *options.SecretSource) (valueInjector, error) {
 | |
| 	value, err := util.GetSecretValue(source)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("error getting secret value: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) {
 | |
| 		header.Add(name, string(value))
 | |
| 	}), nil
 | |
| }
 | |
| 
 | |
| func newClaimInjector(name string, source *options.ClaimSource) (valueInjector, error) {
 | |
| 	switch {
 | |
| 	case source.BasicAuthPassword != nil:
 | |
| 		password, err := util.GetSecretValue(source.BasicAuthPassword)
 | |
| 		if err != nil {
 | |
| 			return nil, fmt.Errorf("error loading basicAuthPassword: %v", err)
 | |
| 		}
 | |
| 		return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) {
 | |
| 			claim := session.GetClaim(source.Claim)
 | |
| 			if claim == "" {
 | |
| 				return
 | |
| 			}
 | |
| 			auth := claim + ":" + string(password)
 | |
| 			header.Add(name, "Basic "+base64.StdEncoding.EncodeToString([]byte(auth)))
 | |
| 		}), nil
 | |
| 	case source.Prefix != "":
 | |
| 		return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) {
 | |
| 			claim := session.GetClaim(source.Claim)
 | |
| 			if claim == "" {
 | |
| 				return
 | |
| 			}
 | |
| 			header.Add(name, source.Prefix+claim)
 | |
| 		}), nil
 | |
| 	default:
 | |
| 		return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) {
 | |
| 			claim := session.GetClaim(source.Claim)
 | |
| 			if claim == "" {
 | |
| 				return
 | |
| 			}
 | |
| 			header.Add(name, claim)
 | |
| 		}), nil
 | |
| 	}
 | |
| }
 |