400 lines
8.7 KiB
Go
400 lines
8.7 KiB
Go
package client
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"time"
|
|
|
|
"github.com/cirruslabs/orchard/internal/config"
|
|
"github.com/cirruslabs/orchard/internal/dialer"
|
|
"github.com/cirruslabs/orchard/internal/version"
|
|
"github.com/cirruslabs/orchard/rpc"
|
|
"github.com/coder/websocket"
|
|
"google.golang.org/grpc/credentials"
|
|
"google.golang.org/grpc/credentials/insecure"
|
|
"google.golang.org/grpc/metadata"
|
|
)
|
|
|
|
type APIError struct {
|
|
StatusCode int
|
|
}
|
|
|
|
func (apiError *APIError) Error() string {
|
|
return "API client encountered an error while attempting"
|
|
}
|
|
|
|
func (apiError *APIError) Is(target error) bool {
|
|
_, ok := target.(*APIError)
|
|
return ok
|
|
}
|
|
|
|
var (
|
|
ErrFailed = errors.New("API client failed")
|
|
ErrAPI = &APIError{}
|
|
ErrInvalidState = errors.New("invalid state")
|
|
)
|
|
|
|
type Client struct {
|
|
address string
|
|
insecure bool
|
|
trustedCertificate *x509.Certificate
|
|
tlsConfig *tls.Config
|
|
|
|
httpClient *http.Client
|
|
baseURL *url.URL
|
|
|
|
serviceAccountName string
|
|
serviceAccountToken string
|
|
|
|
dialer dialer.Dialer
|
|
}
|
|
|
|
type Config struct {
|
|
Address string
|
|
TLSConfig *tls.Config
|
|
}
|
|
|
|
func New(opts ...Option) (*Client, error) {
|
|
client := &Client{}
|
|
|
|
// Apply options
|
|
for _, opt := range opts {
|
|
opt(client)
|
|
}
|
|
|
|
// Apply defaults
|
|
if client.address == "" {
|
|
if err := client.configureFromDefaultContext(); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
if client.trustedCertificate != nil {
|
|
privatePool := x509.NewCertPool()
|
|
privatePool.AddCert(client.trustedCertificate)
|
|
|
|
client.tlsConfig = &tls.Config{
|
|
MinVersion: tls.VersionTLS12,
|
|
RootCAs: privatePool,
|
|
}
|
|
|
|
if len(client.trustedCertificate.DNSNames) != 0 {
|
|
client.tlsConfig.ServerName = client.trustedCertificate.DNSNames[0]
|
|
}
|
|
}
|
|
|
|
// Instantiate the HTTP client
|
|
transport := &http.Transport{
|
|
TLSClientConfig: client.tlsConfig,
|
|
}
|
|
|
|
if client.dialer != nil {
|
|
transport.DialContext = client.dialer.DialContext
|
|
}
|
|
|
|
client.httpClient = &http.Client{
|
|
// The default is zero, which means no timeout, which means that
|
|
// the requests may hang indefinitely. See [1] for more details.
|
|
//
|
|
// [1]: https://github.com/cirruslabs/orchard/issues/152#issuecomment-1927091747
|
|
Timeout: 30 * time.Second,
|
|
Transport: transport,
|
|
}
|
|
|
|
url, err := url.Parse(client.address)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
client.baseURL = url
|
|
|
|
// Figure out if HTTP (insecure) or HTTPS (secure) was requested,
|
|
// so we can further adapt for gRPC and WebSocket usage patterns
|
|
switch client.baseURL.Scheme {
|
|
case "http":
|
|
client.insecure = true
|
|
case "https":
|
|
// do nothing, we're secure by default
|
|
default:
|
|
return nil, fmt.Errorf("%w: only http https schemes are supported, got %s",
|
|
ErrFailed, client.baseURL.Scheme)
|
|
}
|
|
|
|
return client, nil
|
|
}
|
|
|
|
func (client *Client) GRPCTarget() string {
|
|
return client.baseURL.Host
|
|
}
|
|
|
|
func (client *Client) GRPCTransportCredentials() credentials.TransportCredentials {
|
|
if client.insecure {
|
|
return insecure.NewCredentials()
|
|
}
|
|
|
|
return credentials.NewTLS(client.tlsConfig)
|
|
}
|
|
|
|
func (client *Client) GPRCMetadata() metadata.MD {
|
|
result := map[string]string{}
|
|
|
|
if client.serviceAccountName != "" && client.serviceAccountToken != "" {
|
|
result = map[string]string{
|
|
rpc.MetadataServiceAccountNameKey: client.serviceAccountName,
|
|
rpc.MetadataServiceAccountTokenKey: client.serviceAccountToken,
|
|
}
|
|
}
|
|
|
|
return metadata.New(result)
|
|
}
|
|
|
|
func (client *Client) configureFromDefaultContext() error {
|
|
configHandle, err := config.NewHandle()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
defaultContext, err := configHandle.DefaultContext()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
client.address = defaultContext.URL
|
|
client.serviceAccountName = defaultContext.ServiceAccountName
|
|
client.serviceAccountToken = defaultContext.ServiceAccountToken
|
|
|
|
if client.trustedCertificate == nil {
|
|
client.trustedCertificate, err = defaultContext.TrustedCertificate()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (client *Client) requestWithHeaders(
|
|
ctx context.Context,
|
|
method string,
|
|
path string,
|
|
in interface{},
|
|
out interface{},
|
|
params map[string]string,
|
|
) (http.Header, error) {
|
|
var body io.Reader
|
|
|
|
if in != nil {
|
|
jsonBytes, err := json.Marshal(in)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w to marshal request body: %v", ErrFailed, err)
|
|
}
|
|
|
|
body = bytes.NewBuffer(jsonBytes)
|
|
}
|
|
|
|
endpointURL := client.formatPath(path)
|
|
|
|
values := endpointURL.Query()
|
|
for key, value := range params {
|
|
values.Set(key, value)
|
|
}
|
|
endpointURL.RawQuery = values.Encode()
|
|
|
|
request, err := http.NewRequestWithContext(ctx, method, endpointURL.String(), body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w instantiate a request: %v", ErrFailed, err)
|
|
}
|
|
|
|
client.modifyHeader(request.Header)
|
|
|
|
response, err := client.httpClient.Do(request)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w to make a request: %v", ErrFailed, err)
|
|
}
|
|
defer func() {
|
|
_ = response.Body.Close()
|
|
}()
|
|
|
|
if response.StatusCode != http.StatusOK {
|
|
apiError := &APIError{
|
|
StatusCode: response.StatusCode,
|
|
}
|
|
|
|
return nil, fmt.Errorf("%w to make a request: %d %s%s",
|
|
apiError, response.StatusCode, http.StatusText(response.StatusCode),
|
|
detailsFromErrorResponseBody(response.Body))
|
|
}
|
|
|
|
if out != nil {
|
|
bodyBytes, err := io.ReadAll(response.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w to read response body: %v", ErrAPI, err)
|
|
}
|
|
|
|
if err := json.Unmarshal(bodyBytes, out); err != nil {
|
|
return nil, fmt.Errorf("%w to unmarshal response body: %v", ErrAPI, err)
|
|
}
|
|
}
|
|
|
|
return response.Header, nil
|
|
}
|
|
|
|
func (client *Client) request(
|
|
ctx context.Context,
|
|
method string,
|
|
path string,
|
|
in interface{},
|
|
out interface{},
|
|
params map[string]string,
|
|
) error {
|
|
_, err := client.requestWithHeaders(ctx, method, path, in, out, params)
|
|
return err
|
|
}
|
|
|
|
func detailsFromErrorResponseBody(body io.Reader) string {
|
|
bodyBytes, err := io.ReadAll(body)
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
|
|
var errorResponse struct {
|
|
Message string `json:"message"`
|
|
}
|
|
|
|
if err := json.Unmarshal(bodyBytes, &errorResponse); err != nil {
|
|
return ""
|
|
}
|
|
|
|
if errorResponse.Message != "" {
|
|
return fmt.Sprintf(" (%s)", errorResponse.Message)
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
func (client *Client) wsRequest(
|
|
ctx context.Context,
|
|
path string,
|
|
params map[string]string,
|
|
) (net.Conn, error) {
|
|
wsConn, err := client.wsRequestRaw(ctx, path, params)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return websocket.NetConn(ctx, wsConn, websocket.MessageBinary), nil
|
|
}
|
|
|
|
func (client *Client) wsRequestRaw(
|
|
ctx context.Context,
|
|
path string,
|
|
params map[string]string,
|
|
) (*websocket.Conn, error) {
|
|
endpointURL := client.formatPath(path)
|
|
|
|
// Adapt HTTP scheme to WebSocket scheme
|
|
if client.insecure {
|
|
endpointURL.Scheme = "ws"
|
|
} else {
|
|
endpointURL.Scheme = "wss"
|
|
}
|
|
|
|
values := endpointURL.Query()
|
|
for key, value := range params {
|
|
values.Set(key, value)
|
|
}
|
|
endpointURL.RawQuery = values.Encode()
|
|
|
|
dialOptions := &websocket.DialOptions{
|
|
HTTPClient: client.httpClient,
|
|
HTTPHeader: make(http.Header),
|
|
}
|
|
|
|
client.modifyHeader(dialOptions.HTTPHeader)
|
|
|
|
conn, resp, err := websocket.Dial(ctx, endpointURL.String(), dialOptions)
|
|
if err != nil {
|
|
if resp != nil {
|
|
_ = resp.Body.Close()
|
|
|
|
if resp.StatusCode == http.StatusNotFound {
|
|
err = fmt.Errorf("%w (are you sure this VM exists on the controller?)", err)
|
|
}
|
|
}
|
|
|
|
return nil, err
|
|
}
|
|
|
|
return conn, nil
|
|
}
|
|
|
|
func (client *Client) formatPath(path string) *url.URL {
|
|
endpointURL := &url.URL{
|
|
Scheme: client.baseURL.Scheme,
|
|
User: client.baseURL.User,
|
|
Host: client.baseURL.Host,
|
|
Path: client.baseURL.Path,
|
|
}
|
|
|
|
return endpointURL.JoinPath("v1", path)
|
|
}
|
|
|
|
func (client *Client) modifyHeader(header http.Header) {
|
|
header.Set("User-Agent", fmt.Sprintf("Orchard/%s", version.FullVersion))
|
|
|
|
if client.serviceAccountName != "" && client.serviceAccountToken != "" {
|
|
authPlain := fmt.Sprintf("%s:%s", client.serviceAccountName, client.serviceAccountToken)
|
|
authEncoded := base64.StdEncoding.EncodeToString([]byte(authPlain))
|
|
header.Set("Authorization", fmt.Sprintf("Basic %s", authEncoded))
|
|
}
|
|
}
|
|
|
|
func (client *Client) Check(ctx context.Context) error {
|
|
return client.request(ctx, http.MethodGet, "/", nil, nil, nil)
|
|
}
|
|
|
|
func (client *Client) Workers() *WorkersService {
|
|
return &WorkersService{
|
|
client: client,
|
|
}
|
|
}
|
|
|
|
func (client *Client) VMs() *VMsService {
|
|
return &VMsService{
|
|
client: client,
|
|
}
|
|
}
|
|
|
|
func (client *Client) ServiceAccounts() *ServiceAccountsService {
|
|
return &ServiceAccountsService{
|
|
client: client,
|
|
}
|
|
}
|
|
|
|
func (client *Client) Controller() *ControllerService {
|
|
return &ControllerService{
|
|
client: client,
|
|
}
|
|
}
|
|
|
|
func (client *Client) ClusterSettings() *ClusterSettingsService {
|
|
return &ClusterSettingsService{
|
|
client: client,
|
|
}
|
|
}
|
|
|
|
func (client *Client) RPC() *RPCService {
|
|
return &RPCService{
|
|
client: client,
|
|
}
|
|
}
|