413 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			413 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Go
		
	
	
	
/*
 | 
						|
 *
 | 
						|
 * Copyright 2021 Google LLC
 | 
						|
 *
 | 
						|
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
 * you may not use this file except in compliance with the License.
 | 
						|
 * You may obtain a copy of the License at
 | 
						|
 *
 | 
						|
 *     https://www.apache.org/licenses/LICENSE-2.0
 | 
						|
 *
 | 
						|
 * Unless required by applicable law or agreed to in writing, software
 | 
						|
 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
 * See the License for the specific language governing permissions and
 | 
						|
 * limitations under the License.
 | 
						|
 *
 | 
						|
 */
 | 
						|
 | 
						|
// Package s2a provides the S2A transport credentials used by a gRPC
 | 
						|
// application.
 | 
						|
package s2a
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"crypto/tls"
 | 
						|
	"errors"
 | 
						|
	"fmt"
 | 
						|
	"net"
 | 
						|
	"sync"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/golang/protobuf/proto"
 | 
						|
	"github.com/google/s2a-go/fallback"
 | 
						|
	"github.com/google/s2a-go/internal/handshaker"
 | 
						|
	"github.com/google/s2a-go/internal/handshaker/service"
 | 
						|
	"github.com/google/s2a-go/internal/tokenmanager"
 | 
						|
	"github.com/google/s2a-go/internal/v2"
 | 
						|
	"google.golang.org/grpc/credentials"
 | 
						|
	"google.golang.org/grpc/grpclog"
 | 
						|
 | 
						|
	commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
 | 
						|
	s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
 | 
						|
)
 | 
						|
 | 
						|
const (
	// s2aSecurityProtocol is the SecurityProtocol value placed in the
	// credentials.ProtocolInfo of legacy-mode transport credentials.
	s2aSecurityProtocol = "tls"

	// defaultTimeout is the deadline attached to the context that
	// ServerHandshake hands to the server handshaker.
	defaultTimeout = 30 * time.Second
)
 | 
						|
 | 
						|
// s2aTransportCreds are the transport credentials required for establishing
 | 
						|
// a secure connection using the S2A. They implement the
 | 
						|
// credentials.TransportCredentials interface.
 | 
						|
type s2aTransportCreds struct {
 | 
						|
	info          *credentials.ProtocolInfo
 | 
						|
	minTLSVersion commonpb.TLSVersion
 | 
						|
	maxTLSVersion commonpb.TLSVersion
 | 
						|
	// tlsCiphersuites contains the ciphersuites used in the S2A connection.
 | 
						|
	// Note that these are currently unconfigurable.
 | 
						|
	tlsCiphersuites []commonpb.Ciphersuite
 | 
						|
	// localIdentity should only be used by the client.
 | 
						|
	localIdentity *commonpb.Identity
 | 
						|
	// localIdentities should only be used by the server.
 | 
						|
	localIdentities []*commonpb.Identity
 | 
						|
	// targetIdentities should only be used by the client.
 | 
						|
	targetIdentities            []*commonpb.Identity
 | 
						|
	isClient                    bool
 | 
						|
	s2aAddr                     string
 | 
						|
	ensureProcessSessionTickets *sync.WaitGroup
 | 
						|
}
 | 
						|
 | 
						|
// NewClientCreds returns a client-side transport credentials object that uses
 | 
						|
// the S2A to establish a secure connection with a server.
 | 
						|
func NewClientCreds(opts *ClientOptions) (credentials.TransportCredentials, error) {
 | 
						|
	if opts == nil {
 | 
						|
		return nil, errors.New("nil client options")
 | 
						|
	}
 | 
						|
	var targetIdentities []*commonpb.Identity
 | 
						|
	for _, targetIdentity := range opts.TargetIdentities {
 | 
						|
		protoTargetIdentity, err := toProtoIdentity(targetIdentity)
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
		targetIdentities = append(targetIdentities, protoTargetIdentity)
 | 
						|
	}
 | 
						|
	localIdentity, err := toProtoIdentity(opts.LocalIdentity)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	if opts.EnableLegacyMode {
 | 
						|
		return &s2aTransportCreds{
 | 
						|
			info: &credentials.ProtocolInfo{
 | 
						|
				SecurityProtocol: s2aSecurityProtocol,
 | 
						|
			},
 | 
						|
			minTLSVersion: commonpb.TLSVersion_TLS1_3,
 | 
						|
			maxTLSVersion: commonpb.TLSVersion_TLS1_3,
 | 
						|
			tlsCiphersuites: []commonpb.Ciphersuite{
 | 
						|
				commonpb.Ciphersuite_AES_128_GCM_SHA256,
 | 
						|
				commonpb.Ciphersuite_AES_256_GCM_SHA384,
 | 
						|
				commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
 | 
						|
			},
 | 
						|
			localIdentity:               localIdentity,
 | 
						|
			targetIdentities:            targetIdentities,
 | 
						|
			isClient:                    true,
 | 
						|
			s2aAddr:                     opts.S2AAddress,
 | 
						|
			ensureProcessSessionTickets: opts.EnsureProcessSessionTickets,
 | 
						|
		}, nil
 | 
						|
	}
 | 
						|
	verificationMode := getVerificationMode(opts.VerificationMode)
 | 
						|
	var fallbackFunc fallback.ClientHandshake
 | 
						|
	if opts.FallbackOpts != nil && opts.FallbackOpts.FallbackClientHandshakeFunc != nil {
 | 
						|
		fallbackFunc = opts.FallbackOpts.FallbackClientHandshakeFunc
 | 
						|
	}
 | 
						|
	return v2.NewClientCreds(opts.S2AAddress, localIdentity, verificationMode, fallbackFunc, opts.getS2AStream, opts.serverAuthorizationPolicy)
 | 
						|
}
 | 
						|
 | 
						|
// NewServerCreds returns a server-side transport credentials object that uses
 | 
						|
// the S2A to establish a secure connection with a client.
 | 
						|
func NewServerCreds(opts *ServerOptions) (credentials.TransportCredentials, error) {
 | 
						|
	if opts == nil {
 | 
						|
		return nil, errors.New("nil server options")
 | 
						|
	}
 | 
						|
	var localIdentities []*commonpb.Identity
 | 
						|
	for _, localIdentity := range opts.LocalIdentities {
 | 
						|
		protoLocalIdentity, err := toProtoIdentity(localIdentity)
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
		localIdentities = append(localIdentities, protoLocalIdentity)
 | 
						|
	}
 | 
						|
	if opts.EnableLegacyMode {
 | 
						|
		return &s2aTransportCreds{
 | 
						|
			info: &credentials.ProtocolInfo{
 | 
						|
				SecurityProtocol: s2aSecurityProtocol,
 | 
						|
			},
 | 
						|
			minTLSVersion: commonpb.TLSVersion_TLS1_3,
 | 
						|
			maxTLSVersion: commonpb.TLSVersion_TLS1_3,
 | 
						|
			tlsCiphersuites: []commonpb.Ciphersuite{
 | 
						|
				commonpb.Ciphersuite_AES_128_GCM_SHA256,
 | 
						|
				commonpb.Ciphersuite_AES_256_GCM_SHA384,
 | 
						|
				commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
 | 
						|
			},
 | 
						|
			localIdentities: localIdentities,
 | 
						|
			isClient:        false,
 | 
						|
			s2aAddr:         opts.S2AAddress,
 | 
						|
		}, nil
 | 
						|
	}
 | 
						|
	verificationMode := getVerificationMode(opts.VerificationMode)
 | 
						|
	return v2.NewServerCreds(opts.S2AAddress, localIdentities, verificationMode, opts.getS2AStream)
 | 
						|
}
 | 
						|
 | 
						|
// ClientHandshake initiates a client-side TLS handshake using the S2A.
 | 
						|
func (c *s2aTransportCreds) ClientHandshake(ctx context.Context, serverAuthority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
 | 
						|
	if !c.isClient {
 | 
						|
		return nil, nil, errors.New("client handshake called using server transport credentials")
 | 
						|
	}
 | 
						|
 | 
						|
	// Connect to the S2A.
 | 
						|
	hsConn, err := service.Dial(c.s2aAddr)
 | 
						|
	if err != nil {
 | 
						|
		grpclog.Infof("Failed to connect to S2A: %v", err)
 | 
						|
		return nil, nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	var cancel context.CancelFunc
 | 
						|
	ctx, cancel = context.WithCancel(ctx)
 | 
						|
	defer cancel()
 | 
						|
 | 
						|
	opts := &handshaker.ClientHandshakerOptions{
 | 
						|
		MinTLSVersion:               c.minTLSVersion,
 | 
						|
		MaxTLSVersion:               c.maxTLSVersion,
 | 
						|
		TLSCiphersuites:             c.tlsCiphersuites,
 | 
						|
		TargetIdentities:            c.targetIdentities,
 | 
						|
		LocalIdentity:               c.localIdentity,
 | 
						|
		TargetName:                  serverAuthority,
 | 
						|
		EnsureProcessSessionTickets: c.ensureProcessSessionTickets,
 | 
						|
	}
 | 
						|
	chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, c.s2aAddr, opts)
 | 
						|
	if err != nil {
 | 
						|
		grpclog.Infof("Call to handshaker.NewClientHandshaker failed: %v", err)
 | 
						|
		return nil, nil, err
 | 
						|
	}
 | 
						|
	defer func() {
 | 
						|
		if err != nil {
 | 
						|
			if closeErr := chs.Close(); closeErr != nil {
 | 
						|
				grpclog.Infof("Close failed unexpectedly: %v", err)
 | 
						|
				err = fmt.Errorf("%v: close unexpectedly failed: %v", err, closeErr)
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}()
 | 
						|
 | 
						|
	secConn, authInfo, err := chs.ClientHandshake(context.Background())
 | 
						|
	if err != nil {
 | 
						|
		grpclog.Infof("Handshake failed: %v", err)
 | 
						|
		return nil, nil, err
 | 
						|
	}
 | 
						|
	return secConn, authInfo, nil
 | 
						|
}
 | 
						|
 | 
						|
// ServerHandshake initiates a server-side TLS handshake using the S2A.
 | 
						|
func (c *s2aTransportCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
 | 
						|
	if c.isClient {
 | 
						|
		return nil, nil, errors.New("server handshake called using client transport credentials")
 | 
						|
	}
 | 
						|
 | 
						|
	// Connect to the S2A.
 | 
						|
	hsConn, err := service.Dial(c.s2aAddr)
 | 
						|
	if err != nil {
 | 
						|
		grpclog.Infof("Failed to connect to S2A: %v", err)
 | 
						|
		return nil, nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
 | 
						|
	defer cancel()
 | 
						|
 | 
						|
	opts := &handshaker.ServerHandshakerOptions{
 | 
						|
		MinTLSVersion:   c.minTLSVersion,
 | 
						|
		MaxTLSVersion:   c.maxTLSVersion,
 | 
						|
		TLSCiphersuites: c.tlsCiphersuites,
 | 
						|
		LocalIdentities: c.localIdentities,
 | 
						|
	}
 | 
						|
	shs, err := handshaker.NewServerHandshaker(ctx, hsConn, rawConn, c.s2aAddr, opts)
 | 
						|
	if err != nil {
 | 
						|
		grpclog.Infof("Call to handshaker.NewServerHandshaker failed: %v", err)
 | 
						|
		return nil, nil, err
 | 
						|
	}
 | 
						|
	defer func() {
 | 
						|
		if err != nil {
 | 
						|
			if closeErr := shs.Close(); closeErr != nil {
 | 
						|
				grpclog.Infof("Close failed unexpectedly: %v", err)
 | 
						|
				err = fmt.Errorf("%v: close unexpectedly failed: %v", err, closeErr)
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}()
 | 
						|
 | 
						|
	secConn, authInfo, err := shs.ServerHandshake(context.Background())
 | 
						|
	if err != nil {
 | 
						|
		grpclog.Infof("Handshake failed: %v", err)
 | 
						|
		return nil, nil, err
 | 
						|
	}
 | 
						|
	return secConn, authInfo, nil
 | 
						|
}
 | 
						|
 | 
						|
func (c *s2aTransportCreds) Info() credentials.ProtocolInfo {
 | 
						|
	return *c.info
 | 
						|
}
 | 
						|
 | 
						|
func (c *s2aTransportCreds) Clone() credentials.TransportCredentials {
 | 
						|
	info := *c.info
 | 
						|
	var localIdentity *commonpb.Identity
 | 
						|
	if c.localIdentity != nil {
 | 
						|
		localIdentity = proto.Clone(c.localIdentity).(*commonpb.Identity)
 | 
						|
	}
 | 
						|
	var localIdentities []*commonpb.Identity
 | 
						|
	if c.localIdentities != nil {
 | 
						|
		localIdentities = make([]*commonpb.Identity, len(c.localIdentities))
 | 
						|
		for i, localIdentity := range c.localIdentities {
 | 
						|
			localIdentities[i] = proto.Clone(localIdentity).(*commonpb.Identity)
 | 
						|
		}
 | 
						|
	}
 | 
						|
	var targetIdentities []*commonpb.Identity
 | 
						|
	if c.targetIdentities != nil {
 | 
						|
		targetIdentities = make([]*commonpb.Identity, len(c.targetIdentities))
 | 
						|
		for i, targetIdentity := range c.targetIdentities {
 | 
						|
			targetIdentities[i] = proto.Clone(targetIdentity).(*commonpb.Identity)
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return &s2aTransportCreds{
 | 
						|
		info:             &info,
 | 
						|
		minTLSVersion:    c.minTLSVersion,
 | 
						|
		maxTLSVersion:    c.maxTLSVersion,
 | 
						|
		tlsCiphersuites:  c.tlsCiphersuites,
 | 
						|
		localIdentity:    localIdentity,
 | 
						|
		localIdentities:  localIdentities,
 | 
						|
		targetIdentities: targetIdentities,
 | 
						|
		isClient:         c.isClient,
 | 
						|
		s2aAddr:          c.s2aAddr,
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (c *s2aTransportCreds) OverrideServerName(serverNameOverride string) error {
 | 
						|
	c.info.ServerName = serverNameOverride
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
// TLSClientConfigOptions holds the per-call parameters for building a client
// TLS config with a TLSClientConfigFactory.
type TLSClientConfigOptions struct {
	// ServerName is the name S2A expects to find when verifying the hostname
	// in the server's certificate. For example:
	//
	//	tlsConfig, _ := factory.Build(ctx, &s2a.TLSClientConfigOptions{
	//		ServerName: "example.com",
	//	})
	ServerName string
}
 | 
						|
 | 
						|
// TLSClientConfigFactory defines the interface for a client TLS config factory.
 | 
						|
type TLSClientConfigFactory interface {
 | 
						|
	Build(ctx context.Context, opts *TLSClientConfigOptions) (*tls.Config, error)
 | 
						|
}
 | 
						|
 | 
						|
// NewTLSClientConfigFactory returns an instance of s2aTLSClientConfigFactory.
 | 
						|
func NewTLSClientConfigFactory(opts *ClientOptions) (TLSClientConfigFactory, error) {
 | 
						|
	if opts == nil {
 | 
						|
		return nil, fmt.Errorf("opts must be non-nil")
 | 
						|
	}
 | 
						|
	if opts.EnableLegacyMode {
 | 
						|
		return nil, fmt.Errorf("NewTLSClientConfigFactory only supports S2Av2")
 | 
						|
	}
 | 
						|
	tokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
 | 
						|
	if err != nil {
 | 
						|
		// The only possible error is: access token not set in the environment,
 | 
						|
		// which is okay in environments other than serverless.
 | 
						|
		grpclog.Infof("Access token manager not initialized: %v", err)
 | 
						|
		return &s2aTLSClientConfigFactory{
 | 
						|
			s2av2Address:              opts.S2AAddress,
 | 
						|
			tokenManager:              nil,
 | 
						|
			verificationMode:          getVerificationMode(opts.VerificationMode),
 | 
						|
			serverAuthorizationPolicy: opts.serverAuthorizationPolicy,
 | 
						|
		}, nil
 | 
						|
	}
 | 
						|
	return &s2aTLSClientConfigFactory{
 | 
						|
		s2av2Address:              opts.S2AAddress,
 | 
						|
		tokenManager:              tokenManager,
 | 
						|
		verificationMode:          getVerificationMode(opts.VerificationMode),
 | 
						|
		serverAuthorizationPolicy: opts.serverAuthorizationPolicy,
 | 
						|
	}, nil
 | 
						|
}
 | 
						|
 | 
						|
type s2aTLSClientConfigFactory struct {
 | 
						|
	s2av2Address              string
 | 
						|
	tokenManager              tokenmanager.AccessTokenManager
 | 
						|
	verificationMode          s2av2pb.ValidatePeerCertificateChainReq_VerificationMode
 | 
						|
	serverAuthorizationPolicy []byte
 | 
						|
}
 | 
						|
 | 
						|
func (f *s2aTLSClientConfigFactory) Build(
 | 
						|
	ctx context.Context, opts *TLSClientConfigOptions) (*tls.Config, error) {
 | 
						|
	serverName := ""
 | 
						|
	if opts != nil && opts.ServerName != "" {
 | 
						|
		serverName = opts.ServerName
 | 
						|
	}
 | 
						|
	return v2.NewClientTLSConfig(ctx, f.s2av2Address, f.tokenManager, f.verificationMode, serverName, f.serverAuthorizationPolicy)
 | 
						|
}
 | 
						|
 | 
						|
func getVerificationMode(verificationMode VerificationModeType) s2av2pb.ValidatePeerCertificateChainReq_VerificationMode {
 | 
						|
	switch verificationMode {
 | 
						|
	case ConnectToGoogle:
 | 
						|
		return s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE
 | 
						|
	case Spiffe:
 | 
						|
		return s2av2pb.ValidatePeerCertificateChainReq_SPIFFE
 | 
						|
	default:
 | 
						|
		return s2av2pb.ValidatePeerCertificateChainReq_UNSPECIFIED
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// NewS2ADialTLSContextFunc returns a dialer which establishes an MTLS connection using S2A.
 | 
						|
// Example use with http.RoundTripper:
 | 
						|
//
 | 
						|
//		dialTLSContext := s2a.NewS2aDialTLSContextFunc(&s2a.ClientOptions{
 | 
						|
//			S2AAddress:         s2aAddress, // required
 | 
						|
//		})
 | 
						|
//	 	transport := http.DefaultTransport
 | 
						|
//	 	transport.DialTLSContext = dialTLSContext
 | 
						|
func NewS2ADialTLSContextFunc(opts *ClientOptions) func(ctx context.Context, network, addr string) (net.Conn, error) {
 | 
						|
 | 
						|
	return func(ctx context.Context, network, addr string) (net.Conn, error) {
 | 
						|
 | 
						|
		fallback := func(err error) (net.Conn, error) {
 | 
						|
			if opts.FallbackOpts != nil && opts.FallbackOpts.FallbackDialer != nil &&
 | 
						|
				opts.FallbackOpts.FallbackDialer.Dialer != nil && opts.FallbackOpts.FallbackDialer.ServerAddr != "" {
 | 
						|
				fbDialer := opts.FallbackOpts.FallbackDialer
 | 
						|
				grpclog.Infof("fall back to dial: %s", fbDialer.ServerAddr)
 | 
						|
				fbConn, fbErr := fbDialer.Dialer.DialContext(ctx, network, fbDialer.ServerAddr)
 | 
						|
				if fbErr != nil {
 | 
						|
					return nil, fmt.Errorf("error fallback to %s: %v; S2A error: %w", fbDialer.ServerAddr, fbErr, err)
 | 
						|
				}
 | 
						|
				return fbConn, nil
 | 
						|
			}
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
 | 
						|
		factory, err := NewTLSClientConfigFactory(opts)
 | 
						|
		if err != nil {
 | 
						|
			grpclog.Infof("error creating S2A client config factory: %v", err)
 | 
						|
			return fallback(err)
 | 
						|
		}
 | 
						|
 | 
						|
		serverName, _, err := net.SplitHostPort(addr)
 | 
						|
		if err != nil {
 | 
						|
			serverName = addr
 | 
						|
		}
 | 
						|
		timeoutCtx, cancel := context.WithTimeout(ctx, v2.GetS2ATimeout())
 | 
						|
		defer cancel()
 | 
						|
		s2aTLSConfig, err := factory.Build(timeoutCtx, &TLSClientConfigOptions{
 | 
						|
			ServerName: serverName,
 | 
						|
		})
 | 
						|
		if err != nil {
 | 
						|
			grpclog.Infof("error building S2A TLS config: %v", err)
 | 
						|
			return fallback(err)
 | 
						|
		}
 | 
						|
 | 
						|
		s2aDialer := &tls.Dialer{
 | 
						|
			Config: s2aTLSConfig,
 | 
						|
		}
 | 
						|
		c, err := s2aDialer.DialContext(ctx, network, addr)
 | 
						|
		if err != nil {
 | 
						|
			grpclog.Infof("error dialing with S2A to %s: %v", addr, err)
 | 
						|
			return fallback(err)
 | 
						|
		}
 | 
						|
		grpclog.Infof("success dialing MTLS to %s with S2A", addr)
 | 
						|
		return c, nil
 | 
						|
	}
 | 
						|
}
 |