orchard/pkg/client/client.go

400 lines
8.7 KiB
Go

package client
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"time"
"github.com/cirruslabs/orchard/internal/config"
"github.com/cirruslabs/orchard/internal/dialer"
"github.com/cirruslabs/orchard/internal/version"
"github.com/cirruslabs/orchard/rpc"
"github.com/coder/websocket"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
)
// APIError describes a failed interaction with the controller's HTTP API.
//
// requestWithHeaders creates one carrying the HTTP status code whenever the
// controller answers with a status other than 200 OK. Its Error text is a
// prefix that is meant to be continued by the wrapping fmt.Errorf("%w ...")
// call, which appends the attempted operation and the status details.
type APIError struct {
	// StatusCode is the HTTP status code returned by the controller; it is
	// zero on the ErrAPI sentinel.
	StatusCode int
}

// Error returns a fixed prefix; the status code and operation details are
// appended by the error that wraps it.
func (apiError *APIError) Error() string {
	const prefix = "API client encountered an error while attempting"

	return prefix
}

// Is makes errors.Is(err, target) succeed for any *APIError target,
// regardless of its StatusCode, so callers can test against ErrAPI.
func (apiError *APIError) Is(target error) bool {
	switch target.(type) {
	case *APIError:
		return true
	default:
		return false
	}
}
// ErrFailed is wrapped by errors in which the client itself could not carry
// out an operation: marshalling the request body, building or sending the
// request, or an unsupported address scheme in New.
var ErrFailed = errors.New("API client failed")

// ErrAPI is wrapped directly by errors that occur while reading or decoding a
// successful response. Because APIError.Is matches any *APIError,
// errors.Is(err, ErrAPI) also reports true for errors carrying an HTTP status.
var ErrAPI = &APIError{}

// ErrInvalidState is a sentinel for invalid-state conditions.
// NOTE(review): not referenced in the visible part of this file — its users
// live elsewhere in the package; confirm semantics there.
var ErrInvalidState = errors.New("invalid state")
// Client is an API client for the Orchard controller, speaking plain HTTP(S)
// requests, WebSocket connections and gRPC. Construct it with New: the zero
// value is not usable because New is what populates httpClient and baseURL.
type Client struct {
	// address is the controller URL as a string; New parses it into baseURL.
	address string
	// insecure is set by New when address uses the "http" scheme; it selects
	// plaintext gRPC credentials and the "ws" WebSocket scheme.
	insecure bool
	// trustedCertificate, when non-nil, is the sole root CA New places in the
	// TLS configuration.
	trustedCertificate *x509.Certificate
	// tlsConfig is used by the HTTP transport and gRPC TLS credentials; New
	// builds it from trustedCertificate when one is configured.
	tlsConfig *tls.Config
	// httpClient is shared by REST requests and WebSocket dials.
	httpClient *http.Client
	// baseURL is the parsed form of address.
	baseURL *url.URL
	// serviceAccountName and serviceAccountToken are sent as HTTP Basic auth
	// and as gRPC metadata, but only when both are non-empty.
	serviceAccountName string
	serviceAccountToken string
	// dialer, when non-nil, replaces the HTTP transport's DialContext.
	dialer dialer.Dialer
}
// Config groups an API address with a TLS configuration.
//
// NOTE(review): Config is not referenced by the code visible in this file;
// presumably it is consumed elsewhere in the package — verify how these
// fields are interpreted before relying on them.
type Config struct {
	// Address is presumably the controller URL — TODO confirm against callers.
	Address string
	// TLSConfig is an optional TLS configuration; how it is applied is not
	// visible here.
	TLSConfig *tls.Config
}
// New constructs a Client configured by opts.
//
// When the options do not provide an address, the remaining connection
// settings are loaded from the default configuration context via
// configureFromDefaultContext.
//
// If a trusted certificate is configured, TLS verification is pinned to a
// private pool containing only that certificate (TLS 1.2 minimum). The first
// DNS name in that certificate, if any, is used as the expected server name.
//
// The address must use the "http" (insecure) or "https" scheme. A malformed
// address returns the url.Parse error. Any other scheme returns an error
// wrapping ErrFailed.
func New(opts ...Option) (*Client, error) {
	client := &Client{}

	// Apply options
	for _, opt := range opts {
		opt(client)
	}

	// Apply defaults
	if client.address == "" {
		if err := client.configureFromDefaultContext(); err != nil {
			return nil, err
		}
	}

	if client.trustedCertificate != nil {
		privatePool := x509.NewCertPool()
		privatePool.AddCert(client.trustedCertificate)

		client.tlsConfig = &tls.Config{
			MinVersion: tls.VersionTLS12,
			RootCAs:    privatePool,
		}

		// Verify the server against the name embedded in the trusted
		// certificate rather than the host portion of the address.
		if len(client.trustedCertificate.DNSNames) != 0 {
			client.tlsConfig.ServerName = client.trustedCertificate.DNSNames[0]
		}
	}

	// Instantiate the HTTP client
	transport := &http.Transport{
		TLSClientConfig: client.tlsConfig,
	}
	if client.dialer != nil {
		transport.DialContext = client.dialer.DialContext
	}

	client.httpClient = &http.Client{
		// The default is zero, which means no timeout, which means that
		// the requests may hang indefinitely. See [1] for more details.
		//
		// [1]: https://github.com/cirruslabs/orchard/issues/152#issuecomment-1927091747
		Timeout:   30 * time.Second,
		Transport: transport,
	}

	// Named baseURL (not url) so the net/url package is not shadowed.
	baseURL, err := url.Parse(client.address)
	if err != nil {
		return nil, err
	}
	client.baseURL = baseURL

	// Figure out if HTTP (insecure) or HTTPS (secure) was requested,
	// so we can further adapt for gRPC and WebSocket usage patterns
	switch baseURL.Scheme {
	case "http":
		client.insecure = true
	case "https":
		// Secure by default; nothing to adjust.
	default:
		// %q makes an empty scheme (e.g. an address without "scheme://") visible.
		return nil, fmt.Errorf("%w: only http and https schemes are supported, got %q",
			ErrFailed, baseURL.Scheme)
	}

	return client, nil
}
// GRPCTarget returns the gRPC dial target: the host (including the port, if
// one was given) of the client's base URL, without scheme or path.
func (client *Client) GRPCTarget() string {
	base := client.baseURL

	return base.Host
}
// GRPCTransportCredentials returns the transport credentials for gRPC
// connections. For an "https" base URL these are TLS credentials built from
// the client's TLS configuration. For an "http" base URL they are plaintext.
func (client *Client) GRPCTransportCredentials() credentials.TransportCredentials {
	if !client.insecure {
		return credentials.NewTLS(client.tlsConfig)
	}

	return insecure.NewCredentials()
}
// GPRCMetadata returns the metadata to attach to outgoing gRPC calls.
//
// When both the service account name and token are configured, they are sent
// under rpc.MetadataServiceAccountNameKey and rpc.MetadataServiceAccountTokenKey.
// Otherwise the returned metadata is empty.
//
// NOTE(review): "GPRC" transposes "GRPC"; the name is kept for API compatibility.
func (client *Client) GPRCMetadata() metadata.MD {
	pairs := make(map[string]string, 2)

	if client.serviceAccountName != "" && client.serviceAccountToken != "" {
		pairs[rpc.MetadataServiceAccountNameKey] = client.serviceAccountName
		pairs[rpc.MetadataServiceAccountTokenKey] = client.serviceAccountToken
	}

	return metadata.New(pairs)
}
// configureFromDefaultContext fills in connection settings from the default
// context of the local configuration (config.NewHandle).
//
// The address is always taken from the context. The service account
// credentials and the trusted certificate are only taken from the context
// when they were not already supplied (e.g. through options passed to New),
// so explicitly provided values are not silently discarded.
//
// It returns any error from opening the configuration, resolving the default
// context, or loading the context's trusted certificate.
func (client *Client) configureFromDefaultContext() error {
	configHandle, err := config.NewHandle()
	if err != nil {
		return err
	}

	defaultContext, err := configHandle.DefaultContext()
	if err != nil {
		return err
	}

	client.address = defaultContext.URL

	// Credentials are only used as a complete pair (see modifyHeader and
	// GPRCMetadata), so fall back to the context's pair unless a complete
	// pair was already supplied.
	if client.serviceAccountName == "" || client.serviceAccountToken == "" {
		client.serviceAccountName = defaultContext.ServiceAccountName
		client.serviceAccountToken = defaultContext.ServiceAccountToken
	}

	if client.trustedCertificate == nil {
		client.trustedCertificate, err = defaultContext.TrustedCertificate()
		if err != nil {
			return err
		}
	}

	return nil
}
// requestWithHeaders performs an HTTP request against the controller API and
// returns the response headers.
//
// path is resolved under the "v1" prefix of the client's base URL (see
// formatPath) and params are merged into the query string. When in is
// non-nil it is JSON-encoded and sent as the request body with a
// "Content-Type: application/json" header. When out is non-nil, a successful
// response body is JSON-decoded into it.
//
// Error behavior:
//   - Any status other than 200 OK yields an error wrapping an *APIError that
//     carries the status code, plus the server's "message" field if present.
//   - Failures to build or send the request wrap ErrFailed.
//   - Failures to read or decode a successful response wrap ErrAPI.
func (client *Client) requestWithHeaders(
	ctx context.Context,
	method string,
	path string,
	in interface{},
	out interface{},
	params map[string]string,
) (http.Header, error) {
	var body io.Reader
	if in != nil {
		jsonBytes, err := json.Marshal(in)
		if err != nil {
			return nil, fmt.Errorf("%w to marshal request body: %v", ErrFailed, err)
		}
		body = bytes.NewBuffer(jsonBytes)
	}

	endpointURL := client.formatPath(path)
	values := endpointURL.Query()
	for key, value := range params {
		values.Set(key, value)
	}
	endpointURL.RawQuery = values.Encode()

	request, err := http.NewRequestWithContext(ctx, method, endpointURL.String(), body)
	if err != nil {
		return nil, fmt.Errorf("%w to instantiate a request: %v", ErrFailed, err)
	}

	client.modifyHeader(request.Header)
	if in != nil {
		// Declare the body's media type so the server and any intermediaries
		// do not have to guess it.
		request.Header.Set("Content-Type", "application/json")
	}

	response, err := client.httpClient.Do(request)
	if err != nil {
		return nil, fmt.Errorf("%w to make a request: %v", ErrFailed, err)
	}
	defer func() {
		// Drain any unread remainder (e.g. when out == nil) before closing,
		// so the transport can reuse the connection. Both calls are
		// best-effort cleanup.
		_, _ = io.Copy(io.Discard, response.Body)
		_ = response.Body.Close()
	}()

	if response.StatusCode != http.StatusOK {
		apiError := &APIError{
			StatusCode: response.StatusCode,
		}

		return nil, fmt.Errorf("%w to make a request: %d %s%s",
			apiError, response.StatusCode, http.StatusText(response.StatusCode),
			detailsFromErrorResponseBody(response.Body))
	}

	if out != nil {
		bodyBytes, err := io.ReadAll(response.Body)
		if err != nil {
			return nil, fmt.Errorf("%w to read response body: %v", ErrAPI, err)
		}

		if err := json.Unmarshal(bodyBytes, out); err != nil {
			return nil, fmt.Errorf("%w to unmarshal response body: %v", ErrAPI, err)
		}
	}

	return response.Header, nil
}
// request is requestWithHeaders for callers that do not need the response
// headers: it performs the same request and returns only the error.
func (client *Client) request(
	ctx context.Context,
	method string,
	path string,
	in interface{},
	out interface{},
	params map[string]string,
) error {
	if _, err := client.requestWithHeaders(ctx, method, path, in, out, params); err != nil {
		return err
	}

	return nil
}
// detailsFromErrorResponseBody extracts a human-readable suffix from an API
// error response body.
//
// When the body decodes as JSON into an object with a non-empty "message"
// string, the result is " (<message>)", ready to be appended to an error
// string. Read failures, bodies that do not decode, and missing or empty
// messages all yield "".
func detailsFromErrorResponseBody(body io.Reader) string {
	var payload struct {
		Message string `json:"message"`
	}

	raw, readErr := io.ReadAll(body)
	if readErr != nil {
		return ""
	}

	if json.Unmarshal(raw, &payload) != nil || payload.Message == "" {
		return ""
	}

	return fmt.Sprintf(" (%s)", payload.Message)
}
// wsRequest dials a WebSocket endpoint via wsRequestRaw and adapts the
// connection into a net.Conn that exchanges binary WebSocket messages.
//
// ctx is used both for the dial and for websocket.NetConn. Per coder/websocket's
// NetConn documentation it bounds the lifetime of the returned connection, so
// callers presumably need to keep it alive for as long as they use the conn.
func (client *Client) wsRequest(
	ctx context.Context,
	path string,
	params map[string]string,
) (net.Conn, error) {
	rawConn, err := client.wsRequestRaw(ctx, path, params)
	if err != nil {
		return nil, err
	}

	netConn := websocket.NetConn(ctx, rawConn, websocket.MessageBinary)

	return netConn, nil
}
// wsRequestRaw dials a WebSocket endpoint on the controller.
//
// The endpoint URL is built like REST endpoints (see formatPath), with two
// differences:
//   - the scheme is switched to "ws" for insecure clients and "wss" otherwise;
//   - params are merged into the query string.
//
// The handshake carries the same headers as REST calls (see modifyHeader) and
// goes through the client's shared HTTP client. If the handshake fails with a
// 404, the error is annotated with a hint asking whether the VM exists on the
// controller.
func (client *Client) wsRequestRaw(
	ctx context.Context,
	path string,
	params map[string]string,
) (*websocket.Conn, error) {
	target := client.formatPath(path)

	// Adapt HTTP scheme to WebSocket scheme
	wsScheme := "wss"
	if client.insecure {
		wsScheme = "ws"
	}
	target.Scheme = wsScheme

	query := target.Query()
	for name, val := range params {
		query.Set(name, val)
	}
	target.RawQuery = query.Encode()

	headers := http.Header{}
	client.modifyHeader(headers)

	conn, resp, err := websocket.Dial(ctx, target.String(), &websocket.DialOptions{
		HTTPClient: client.httpClient,
		HTTPHeader: headers,
	})
	if err == nil {
		return conn, nil
	}

	if resp != nil {
		_ = resp.Body.Close()

		if resp.StatusCode == http.StatusNotFound {
			err = fmt.Errorf("%w (are you sure this VM exists on the controller?)", err)
		}
	}

	return nil, err
}
// formatPath returns a fresh URL for an API endpoint.
//
// It starts from the client's base URL, keeping its scheme, userinfo, host and
// path but not its query or fragment. It then appends "v1" and path via
// url.URL.JoinPath. The result may be modified freely without affecting
// baseURL.
func (client *Client) formatPath(path string) *url.URL {
	base := client.baseURL
	root := url.URL{
		Scheme: base.Scheme,
		User:   base.User,
		Host:   base.Host,
		Path:   base.Path,
	}

	return root.JoinPath("v1", path)
}
// modifyHeader sets the headers every controller request carries:
//   - a User-Agent of the form "Orchard/<version>";
//   - HTTP Basic credentials, but only when both the service account name and
//     token are configured.
func (client *Client) modifyHeader(header http.Header) {
	header.Set("User-Agent", fmt.Sprintf("Orchard/%s", version.FullVersion))

	name, token := client.serviceAccountName, client.serviceAccountToken
	if name == "" || token == "" {
		return
	}

	encoded := base64.StdEncoding.EncodeToString([]byte(name + ":" + token))
	header.Set("Authorization", "Basic "+encoded)
}
func (client *Client) Check(ctx context.Context) error {
return client.request(ctx, http.MethodGet, "/", nil, nil, nil)
}
// Workers returns a WorkersService bound to this client.
func (client *Client) Workers() *WorkersService { return &WorkersService{client: client} }

// VMs returns a VMsService bound to this client.
func (client *Client) VMs() *VMsService { return &VMsService{client: client} }

// ServiceAccounts returns a ServiceAccountsService bound to this client.
func (client *Client) ServiceAccounts() *ServiceAccountsService {
	return &ServiceAccountsService{client: client}
}

// Controller returns a ControllerService bound to this client.
func (client *Client) Controller() *ControllerService { return &ControllerService{client: client} }

// ClusterSettings returns a ClusterSettingsService bound to this client.
func (client *Client) ClusterSettings() *ClusterSettingsService {
	return &ClusterSettingsService{client: client}
}

// RPC returns an RPCService bound to this client.
func (client *Client) RPC() *RPCService { return &RPCService{client: client} }