Make claims list of strings
This commit is contained in:
		
							parent
							
								
									d88d97b440
								
							
						
					
					
						commit
						a27d71b692
					
				|  | @ -8,7 +8,6 @@ import ( | |||
| 	"io" | ||||
| 	"io/ioutil" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 	"unicode/utf8" | ||||
| 
 | ||||
|  | @ -70,31 +69,33 @@ func (s *SessionState) String() string { | |||
| 	return o + "}" | ||||
| } | ||||
| 
 | ||||
| func (s *SessionState) GetClaim(claim string) string { | ||||
| func (s *SessionState) GetClaim(claim string) []string { | ||||
| 	if s == nil { | ||||
| 		return "" | ||||
| 		return []string{} | ||||
| 	} | ||||
| 	switch claim { | ||||
| 	case "access_token": | ||||
| 		return s.AccessToken | ||||
| 		return []string{s.AccessToken} | ||||
| 	case "id_token": | ||||
| 		return s.IDToken | ||||
| 		return []string{s.IDToken} | ||||
| 	case "created_at": | ||||
| 		return s.CreatedAt.String() | ||||
| 		return []string{s.CreatedAt.String()} | ||||
| 	case "expires_on": | ||||
| 		return s.ExpiresOn.String() | ||||
| 		return []string{s.ExpiresOn.String()} | ||||
| 	case "refresh_token": | ||||
| 		return s.RefreshToken | ||||
| 		return []string{s.RefreshToken} | ||||
| 	case "email": | ||||
| 		return s.Email | ||||
| 		return []string{s.Email} | ||||
| 	case "user": | ||||
| 		return s.User | ||||
| 		return []string{s.User} | ||||
| 	case "groups": | ||||
| 		return strings.Join(s.Groups, ",") | ||||
| 		groups := make([]string, len(s.Groups)) | ||||
| 		copy(groups, s.Groups) | ||||
| 		return groups | ||||
| 	case "preferred_username": | ||||
| 		return s.PreferredUsername | ||||
| 		return []string{s.PreferredUsername} | ||||
| 	default: | ||||
| 		return "" | ||||
| 		return []string{} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -85,28 +85,34 @@ func newClaimInjector(name string, source *options.ClaimSource) (valueInjector, | |||
| 			return nil, fmt.Errorf("error loading basicAuthPassword: %v", err) | ||||
| 		} | ||||
| 		return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) { | ||||
| 			claim := session.GetClaim(source.Claim) | ||||
| 			if claim == "" { | ||||
| 				return | ||||
| 			claimValues := session.GetClaim(source.Claim) | ||||
| 			for _, claim := range claimValues { | ||||
| 				if claim == "" { | ||||
| 					continue | ||||
| 				} | ||||
| 				auth := claim + ":" + string(password) | ||||
| 				header.Add(name, "Basic "+base64.StdEncoding.EncodeToString([]byte(auth))) | ||||
| 			} | ||||
| 			auth := claim + ":" + string(password) | ||||
| 			header.Add(name, "Basic "+base64.StdEncoding.EncodeToString([]byte(auth))) | ||||
| 		}), nil | ||||
| 	case source.Prefix != "": | ||||
| 		return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) { | ||||
| 			claim := session.GetClaim(source.Claim) | ||||
| 			if claim == "" { | ||||
| 				return | ||||
| 			claimValues := session.GetClaim(source.Claim) | ||||
| 			for _, claim := range claimValues { | ||||
| 				if claim == "" { | ||||
| 					continue | ||||
| 				} | ||||
| 				header.Add(name, source.Prefix+claim) | ||||
| 			} | ||||
| 			header.Add(name, source.Prefix+claim) | ||||
| 		}), nil | ||||
| 	default: | ||||
| 		return newInjectorFunc(func(header http.Header, session *sessionsapi.SessionState) { | ||||
| 			claim := session.GetClaim(source.Claim) | ||||
| 			if claim == "" { | ||||
| 				return | ||||
| 			claimValues := session.GetClaim(source.Claim) | ||||
| 			for _, claim := range claimValues { | ||||
| 				if claim == "" { | ||||
| 					continue | ||||
| 				} | ||||
| 				header.Add(name, claim) | ||||
| 			} | ||||
| 			header.Add(name, claim) | ||||
| 		}), nil | ||||
| 	} | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue