feat: add Apple Sign in with Apple OIDC provider.
Signed-off-by: 梁杨峻玮 <lyjw2007@gmail.com>
This commit is contained in:
parent
6a0d821df8
commit
51874e8275
|
|
@ -93,6 +93,8 @@ type Provider struct {
|
|||
OIDCConfig OIDCOptions `yaml:"oidcConfig,omitempty"`
|
||||
// LoginGovConfig holds all configurations for LoginGov provider.
|
||||
LoginGovConfig LoginGovOptions `yaml:"loginGovConfig,omitempty"`
|
||||
// AppleConfig holds all configurations for Apple provider.
|
||||
AppleConfig AppleOptions `yaml:"appleConfig,omitempty"`
|
||||
|
||||
// ID should be a unique identifier for the provider.
|
||||
// This value is required for all providers.
|
||||
|
|
@ -195,6 +197,9 @@ const (
|
|||
|
||||
// SourceHutProvider is the provider type for SourceHut
|
||||
SourceHutProvider ProviderType = "sourcehut"
|
||||
|
||||
// AppleProvider is the provider type for Apple Sign in with Apple
|
||||
AppleProvider ProviderType = "apple"
|
||||
)
|
||||
|
||||
type KeycloakOptions struct {
|
||||
|
|
@ -329,6 +334,17 @@ type LoginGovOptions struct {
|
|||
PubJWKURL string `yaml:"pubjwkURL,omitempty"`
|
||||
}
|
||||
|
||||
type AppleOptions struct {
|
||||
// TeamID is the 10-character Apple Developer Team ID
|
||||
TeamID string `yaml:"teamID,omitempty"`
|
||||
// KeyID is the 10-character identifier for the private key
|
||||
KeyID string `yaml:"keyID,omitempty"`
|
||||
// PrivateKey is the PEM-encoded ES256 private key content (from .p8 file)
|
||||
PrivateKey string `yaml:"privateKey,omitempty"`
|
||||
// PrivateKeyFile is the path to the .p8 private key file
|
||||
PrivateKeyFile string `yaml:"privateKeyFile,omitempty"`
|
||||
}
|
||||
|
||||
// Legacy default providers configuration
|
||||
func providerDefaults() Providers {
|
||||
providers := Providers{
|
||||
|
|
|
|||
|
|
@ -0,0 +1,305 @@
|
|||
package providers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
|
||||
)
|
||||
|
||||
const (
|
||||
appleProviderName = "Apple"
|
||||
appleDefaultScope = "openid email name"
|
||||
|
||||
appleIssuerURL = "https://appleid.apple.com"
|
||||
appleAuthURL = "https://appleid.apple.com/auth/authorize"
|
||||
appleTokenURL = "https://appleid.apple.com/auth/token"
|
||||
appleAudience = "https://appleid.apple.com"
|
||||
)
|
||||
|
||||
var (
|
||||
appleDefaultLoginURL = &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "appleid.apple.com",
|
||||
Path: "/auth/authorize",
|
||||
}
|
||||
|
||||
appleDefaultRedeemURL = &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "appleid.apple.com",
|
||||
Path: "/auth/token",
|
||||
}
|
||||
)
|
||||
|
||||
// AppleProvider represents the Apple Sign in with Apple OIDC provider
|
||||
type AppleProvider struct {
|
||||
*OIDCProvider
|
||||
|
||||
TeamID string
|
||||
KeyID string
|
||||
PrivateKey *ecdsa.PrivateKey
|
||||
}
|
||||
|
||||
var _ Provider = (*AppleProvider)(nil)
|
||||
|
||||
// NewAppleProvider creates a new AppleProvider
|
||||
func NewAppleProvider(p *ProviderData, appleOpts options.AppleOptions, oidcOpts options.OIDCOptions) (*AppleProvider, error) {
|
||||
p.setProviderDefaults(providerDefaults{
|
||||
name: appleProviderName,
|
||||
loginURL: appleDefaultLoginURL,
|
||||
redeemURL: appleDefaultRedeemURL,
|
||||
profileURL: nil,
|
||||
validateURL: nil,
|
||||
scope: appleDefaultScope,
|
||||
})
|
||||
p.getAuthorizationHeaderFunc = makeOIDCHeader
|
||||
|
||||
oidcProvider := &OIDCProvider{
|
||||
ProviderData: p,
|
||||
SkipNonce: true, // Apple doesn't use nonce in the standard way
|
||||
}
|
||||
|
||||
provider := &AppleProvider{
|
||||
OIDCProvider: oidcProvider,
|
||||
TeamID: appleOpts.TeamID,
|
||||
KeyID: appleOpts.KeyID,
|
||||
}
|
||||
|
||||
if err := provider.configure(appleOpts); err != nil {
|
||||
return nil, fmt.Errorf("could not configure Apple provider: %v", err)
|
||||
}
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
// configure validates and sets up the Apple provider with the private key
|
||||
func (p *AppleProvider) configure(opts options.AppleOptions) error {
|
||||
if opts.TeamID == "" {
|
||||
return errors.New("apple provider requires teamID")
|
||||
}
|
||||
if opts.KeyID == "" {
|
||||
return errors.New("apple provider requires keyID")
|
||||
}
|
||||
|
||||
// Private key can be supplied via config or file, but not both
|
||||
switch {
|
||||
case opts.PrivateKey != "" && opts.PrivateKeyFile != "":
|
||||
return errors.New("cannot set both privateKey and privateKeyFile options")
|
||||
case opts.PrivateKey == "" && opts.PrivateKeyFile == "":
|
||||
return errors.New("apple provider requires a private key for signing JWTs")
|
||||
case opts.PrivateKey != "":
|
||||
key, err := parseECPrivateKey([]byte(opts.PrivateKey))
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not parse EC private key: %v", err)
|
||||
}
|
||||
p.PrivateKey = key
|
||||
case opts.PrivateKeyFile != "":
|
||||
keyData, err := os.ReadFile(opts.PrivateKeyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not read private key file %s: %v", opts.PrivateKeyFile, err)
|
||||
}
|
||||
key, err := parseECPrivateKey(keyData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not parse private key from file %s: %v", opts.PrivateKeyFile, err)
|
||||
}
|
||||
p.PrivateKey = key
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseECPrivateKey parses a PEM-encoded EC private key (Apple .p8 format)
|
||||
func parseECPrivateKey(keyData []byte) (*ecdsa.PrivateKey, error) {
|
||||
// Apple .p8 files contain a PEM-encoded PKCS#8 private key
|
||||
block, _ := pem.Decode(keyData)
|
||||
if block == nil {
|
||||
return nil, errors.New("failed to decode PEM block")
|
||||
}
|
||||
|
||||
// Try PKCS#8 first (Apple's format)
|
||||
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err == nil {
|
||||
if ecKey, ok := key.(*ecdsa.PrivateKey); ok {
|
||||
return ecKey, nil
|
||||
}
|
||||
return nil, errors.New("key is not an EC private key")
|
||||
}
|
||||
|
||||
// Fall back to EC private key format
|
||||
ecKey, err := x509.ParseECPrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse EC private key: %v", err)
|
||||
}
|
||||
|
||||
return ecKey, nil
|
||||
}
|
||||
|
||||
// generateClientSecret creates a JWT client_secret for Apple token requests
|
||||
// Apple requires the client_secret to be a JWT signed with ES256
|
||||
func (p *AppleProvider) generateClientSecret() (string, error) {
|
||||
now := time.Now()
|
||||
claims := &jwt.RegisteredClaims{
|
||||
Issuer: p.TeamID,
|
||||
Subject: p.ClientID,
|
||||
Audience: jwt.ClaimStrings{appleAudience},
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(5 * time.Minute)), // Short-lived for security
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
|
||||
token.Header["kid"] = p.KeyID
|
||||
|
||||
return token.SignedString(p.PrivateKey)
|
||||
}
|
||||
|
||||
// GetLoginURL returns the Apple authorization URL with required parameters
|
||||
func (p *AppleProvider) GetLoginURL(redirectURI, state, nonce string, extraParams url.Values) string {
|
||||
// Apple requires response_mode=form_post for web clients
|
||||
if extraParams.Get("response_mode") == "" {
|
||||
extraParams.Set("response_mode", "form_post")
|
||||
}
|
||||
return p.OIDCProvider.GetLoginURL(redirectURI, state, nonce, extraParams)
|
||||
}
|
||||
|
||||
// Redeem exchanges the authorization code for tokens
|
||||
func (p *AppleProvider) Redeem(ctx context.Context, redirectURL, code, codeVerifier string) (*sessions.SessionState, error) {
|
||||
if code == "" {
|
||||
return nil, ErrMissingCode
|
||||
}
|
||||
|
||||
clientSecret, err := p.generateClientSecret()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate client secret: %v", err)
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Add("client_id", p.ClientID)
|
||||
params.Add("client_secret", clientSecret)
|
||||
params.Add("code", code)
|
||||
params.Add("grant_type", "authorization_code")
|
||||
params.Add("redirect_uri", redirectURL)
|
||||
if codeVerifier != "" {
|
||||
params.Add("code_verifier", codeVerifier)
|
||||
}
|
||||
|
||||
var jsonResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
}
|
||||
|
||||
err = requests.New(p.RedeemURL.String()).
|
||||
WithContext(ctx).
|
||||
WithMethod("POST").
|
||||
WithBody(bytes.NewBufferString(params.Encode())).
|
||||
SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
||||
Do().
|
||||
UnmarshalInto(&jsonResponse)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token exchange failed: %v", err)
|
||||
}
|
||||
|
||||
ctx = oidc.ClientContext(ctx, requests.DefaultHTTPClient)
|
||||
|
||||
// Build session from ID token claims
|
||||
ss, err := p.buildSessionFromClaims(jsonResponse.IDToken, jsonResponse.AccessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build session from claims: %v", err)
|
||||
}
|
||||
|
||||
ss.AccessToken = jsonResponse.AccessToken
|
||||
ss.RefreshToken = jsonResponse.RefreshToken
|
||||
ss.IDToken = jsonResponse.IDToken
|
||||
|
||||
ss.CreatedAtNow()
|
||||
ss.ExpiresIn(time.Duration(jsonResponse.ExpiresIn) * time.Second)
|
||||
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
// RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
|
||||
func (p *AppleProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) {
|
||||
if s == nil || s.RefreshToken == "" {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
clientSecret, err := p.generateClientSecret()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to generate client secret: %v", err)
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Add("client_id", p.ClientID)
|
||||
params.Add("client_secret", clientSecret)
|
||||
params.Add("grant_type", "refresh_token")
|
||||
params.Add("refresh_token", s.RefreshToken)
|
||||
|
||||
var jsonResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
}
|
||||
|
||||
err = requests.New(p.RedeemURL.String()).
|
||||
WithContext(ctx).
|
||||
WithMethod("POST").
|
||||
WithBody(bytes.NewBufferString(params.Encode())).
|
||||
SetHeader("Content-Type", "application/x-www-form-urlencoded").
|
||||
Do().
|
||||
UnmarshalInto(&jsonResponse)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("refresh token failed: %v", err)
|
||||
}
|
||||
|
||||
// Update session with new tokens
|
||||
if jsonResponse.IDToken != "" {
|
||||
ctx = oidc.ClientContext(ctx, requests.DefaultHTTPClient)
|
||||
newSession, err := p.buildSessionFromClaims(jsonResponse.IDToken, jsonResponse.AccessToken)
|
||||
if err == nil {
|
||||
s.Email = newSession.Email
|
||||
s.User = newSession.User
|
||||
s.Groups = newSession.Groups
|
||||
s.PreferredUsername = newSession.PreferredUsername
|
||||
}
|
||||
s.IDToken = jsonResponse.IDToken
|
||||
}
|
||||
|
||||
s.AccessToken = jsonResponse.AccessToken
|
||||
if jsonResponse.RefreshToken != "" {
|
||||
s.RefreshToken = jsonResponse.RefreshToken
|
||||
}
|
||||
|
||||
s.CreatedAtNow()
|
||||
s.ExpiresIn(time.Duration(jsonResponse.ExpiresIn) * time.Second)
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// ValidateSession validates the session's ID token
|
||||
func (p *AppleProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
|
||||
ctx = oidc.ClientContext(ctx, requests.DefaultHTTPClient)
|
||||
|
||||
if s.IDToken != "" && p.Verifier != nil {
|
||||
if _, err := p.Verifier.Verify(ctx, s.IDToken); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return validateToken(ctx, p, s.AccessToken, makeOIDCHeader(s.AccessToken))
|
||||
}
|
||||
|
|
@ -0,0 +1,282 @@
|
|||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func newAppleServer(body []byte) (*url.URL, *httptest.Server) {
|
||||
s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.Write(body)
|
||||
}))
|
||||
u, _ := url.Parse(s.URL)
|
||||
return u, s
|
||||
}
|
||||
|
||||
func generateTestECPrivateKey() (*ecdsa.PrivateKey, []byte, error) {
|
||||
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
keyBytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pemBlock := &pem.Block{
|
||||
Type: "PRIVATE KEY",
|
||||
Bytes: keyBytes,
|
||||
}
|
||||
|
||||
return privateKey, pem.EncodeToMemory(pemBlock), nil
|
||||
}
|
||||
|
||||
func newAppleProvider() (*AppleProvider, *ecdsa.PrivateKey, error) {
|
||||
privKey, privKeyPEM, err := generateTestECPrivateKey()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
p, err := NewAppleProvider(
|
||||
&ProviderData{
|
||||
ProviderName: "",
|
||||
LoginURL: &url.URL{},
|
||||
RedeemURL: &url.URL{},
|
||||
ProfileURL: &url.URL{},
|
||||
ValidateURL: &url.URL{},
|
||||
Scope: "",
|
||||
ClientID: "com.example.client",
|
||||
},
|
||||
options.AppleOptions{
|
||||
TeamID: "TEAM123456",
|
||||
KeyID: "KEY1234567",
|
||||
PrivateKey: string(privKeyPEM),
|
||||
},
|
||||
options.OIDCOptions{},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return p, privKey, nil
|
||||
}
|
||||
|
||||
func TestNewAppleProvider(t *testing.T) {
|
||||
g := NewWithT(t)
|
||||
|
||||
_, privKeyPEM, err := generateTestECPrivateKey()
|
||||
g.Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Test that defaults are set when calling for a new provider
|
||||
provider, err := NewAppleProvider(
|
||||
&ProviderData{
|
||||
ClientID: "com.example.client",
|
||||
},
|
||||
options.AppleOptions{
|
||||
TeamID: "TEAM123456",
|
||||
KeyID: "KEY1234567",
|
||||
PrivateKey: string(privKeyPEM),
|
||||
},
|
||||
options.OIDCOptions{},
|
||||
)
|
||||
g.Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
providerData := provider.Data()
|
||||
g.Expect(providerData.ProviderName).To(Equal("Apple"))
|
||||
g.Expect(providerData.LoginURL.String()).To(Equal("https://appleid.apple.com/auth/authorize"))
|
||||
g.Expect(providerData.RedeemURL.String()).To(Equal("https://appleid.apple.com/auth/token"))
|
||||
g.Expect(providerData.Scope).To(Equal("openid email name"))
|
||||
}
|
||||
|
||||
func TestAppleProviderMissingTeamID(t *testing.T) {
|
||||
_, privKeyPEM, err := generateTestECPrivateKey()
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = NewAppleProvider(
|
||||
&ProviderData{},
|
||||
options.AppleOptions{
|
||||
KeyID: "KEY1234567",
|
||||
PrivateKey: string(privKeyPEM),
|
||||
},
|
||||
options.OIDCOptions{},
|
||||
)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "teamID")
|
||||
}
|
||||
|
||||
func TestAppleProviderMissingKeyID(t *testing.T) {
|
||||
_, privKeyPEM, err := generateTestECPrivateKey()
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = NewAppleProvider(
|
||||
&ProviderData{},
|
||||
options.AppleOptions{
|
||||
TeamID: "TEAM123456",
|
||||
PrivateKey: string(privKeyPEM),
|
||||
},
|
||||
options.OIDCOptions{},
|
||||
)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "keyID")
|
||||
}
|
||||
|
||||
func TestAppleProviderMissingPrivateKey(t *testing.T) {
|
||||
_, err := NewAppleProvider(
|
||||
&ProviderData{},
|
||||
options.AppleOptions{
|
||||
TeamID: "TEAM123456",
|
||||
KeyID: "KEY1234567",
|
||||
},
|
||||
options.OIDCOptions{},
|
||||
)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "private key")
|
||||
}
|
||||
|
||||
func TestAppleProviderBothPrivateKeyOptions(t *testing.T) {
|
||||
_, privKeyPEM, err := generateTestECPrivateKey()
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = NewAppleProvider(
|
||||
&ProviderData{},
|
||||
options.AppleOptions{
|
||||
TeamID: "TEAM123456",
|
||||
KeyID: "KEY1234567",
|
||||
PrivateKey: string(privKeyPEM),
|
||||
PrivateKeyFile: "/path/to/key.p8",
|
||||
},
|
||||
options.OIDCOptions{},
|
||||
)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "cannot set both")
|
||||
}
|
||||
|
||||
func TestAppleProviderGenerateClientSecret(t *testing.T) {
|
||||
p, privKey, err := newAppleProvider()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, p)
|
||||
|
||||
secret, err := p.generateClientSecret()
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, secret)
|
||||
|
||||
// Verify the JWT
|
||||
token, err := jwt.Parse(secret, func(token *jwt.Token) (interface{}, error) {
|
||||
return &privKey.PublicKey, nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, token.Valid)
|
||||
|
||||
// Verify claims
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "TEAM123456", claims["iss"])
|
||||
assert.Equal(t, "com.example.client", claims["sub"])
|
||||
|
||||
// Verify header
|
||||
assert.Equal(t, "ES256", token.Method.Alg())
|
||||
assert.Equal(t, "KEY1234567", token.Header["kid"])
|
||||
}
|
||||
|
||||
func TestAppleProviderGetLoginURL(t *testing.T) {
|
||||
p, _, err := newAppleProvider()
|
||||
assert.NoError(t, err)
|
||||
|
||||
result := p.GetLoginURL("https://example.com/callback", "state123", "nonce123", url.Values{})
|
||||
assert.Contains(t, result, "response_mode=form_post")
|
||||
assert.Contains(t, result, "state=state123")
|
||||
assert.Contains(t, result, "redirect_uri=")
|
||||
}
|
||||
|
||||
func TestAppleProviderRedeem(t *testing.T) {
|
||||
p, _, err := newAppleProvider()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, p)
|
||||
|
||||
// Create a mock ID token
|
||||
expiresIn := int64(3600)
|
||||
idTokenClaims := jwt.MapClaims{
|
||||
"iss": "https://appleid.apple.com",
|
||||
"sub": "user123",
|
||||
"aud": "com.example.client",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"email": "user@example.com",
|
||||
}
|
||||
|
||||
// Sign with test key for mock purposes
|
||||
privKey, _, _ := generateTestECPrivateKey()
|
||||
idToken := jwt.NewWithClaims(jwt.SigningMethodES256, idTokenClaims)
|
||||
signedIDToken, err := idToken.SignedString(privKey)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Set up mock server response
|
||||
body, err := json.Marshal(map[string]interface{}{
|
||||
"access_token": "mock_access_token",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": expiresIn,
|
||||
"refresh_token": "mock_refresh_token",
|
||||
"id_token": signedIDToken,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
var server *httptest.Server
|
||||
p.RedeemURL, server = newAppleServer(body)
|
||||
defer server.Close()
|
||||
|
||||
session, err := p.Redeem(context.Background(), "https://example.com/callback", "code123", "")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, session)
|
||||
assert.Equal(t, "mock_access_token", session.AccessToken)
|
||||
assert.Equal(t, "mock_refresh_token", session.RefreshToken)
|
||||
assert.Equal(t, signedIDToken, session.IDToken)
|
||||
}
|
||||
|
||||
func TestAppleProviderRefreshSession(t *testing.T) {
|
||||
p, _, err := newAppleProvider()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, p)
|
||||
|
||||
expiresIn := int64(3600)
|
||||
|
||||
// Set up mock server response
|
||||
body, err := json.Marshal(map[string]interface{}{
|
||||
"access_token": "new_access_token",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": expiresIn,
|
||||
"refresh_token": "new_refresh_token",
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
var server *httptest.Server
|
||||
p.RedeemURL, server = newAppleServer(body)
|
||||
defer server.Close()
|
||||
|
||||
session := &sessions.SessionState{
|
||||
RefreshToken: "old_refresh_token",
|
||||
}
|
||||
|
||||
refreshed, err := p.RefreshSession(context.Background(), session)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, refreshed)
|
||||
assert.Equal(t, "new_access_token", session.AccessToken)
|
||||
assert.Equal(t, "new_refresh_token", session.RefreshToken)
|
||||
}
|
||||
|
|
@ -70,6 +70,8 @@ func NewProvider(providerConfig options.Provider) (Provider, error) {
|
|||
return NewNextcloudProvider(providerData), nil
|
||||
case options.OIDCProvider:
|
||||
return NewOIDCProvider(providerData, providerConfig.OIDCConfig), nil
|
||||
case options.AppleProvider:
|
||||
return NewAppleProvider(providerData, providerConfig.AppleConfig, providerConfig.OIDCConfig)
|
||||
case options.SourceHutProvider:
|
||||
return NewSourceHutProvider(providerData), nil
|
||||
default:
|
||||
|
|
@ -192,7 +194,7 @@ func providerRequiresOIDCProviderVerifier(providerType options.ProviderType) (bo
|
|||
options.NextCloudProvider, options.SourceHutProvider:
|
||||
return false, nil
|
||||
case options.OIDCProvider, options.ADFSProvider, options.AzureProvider, options.CidaasProvider,
|
||||
options.GitLabProvider, options.KeycloakOIDCProvider, options.MicrosoftEntraIDProvider:
|
||||
options.GitLabProvider, options.KeycloakOIDCProvider, options.MicrosoftEntraIDProvider, options.AppleProvider:
|
||||
return true, nil
|
||||
default:
|
||||
return false, fmt.Errorf("unknown provider type: %s", providerType)
|
||||
|
|
|
|||
Loading…
Reference in New Issue