package controller

import (
	"context"
	"crypto/subtle"
	"errors"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/cirruslabs/orchard/api"
	storepkg "github.com/cirruslabs/orchard/internal/controller/store"
	"github.com/cirruslabs/orchard/internal/responder"
	"github.com/cirruslabs/orchard/internal/vmtempauth"
	v1pkg "github.com/cirruslabs/orchard/pkg/resource/v1"
	"github.com/cirruslabs/orchard/rpc"
	"github.com/deckarep/golang-set/v2"
	ginzap "github.com/gin-contrib/zap"
	"github.com/gin-gonic/gin"
	"github.com/go-openapi/runtime/middleware"
	"github.com/penglongli/gin-metrics/ginmetrics"
	"github.com/samber/lo"
	"go.uber.org/zap"
	"google.golang.org/grpc/metadata"
)

const (
	// ctxServiceAccountKey is the gin context key under which
	// authenticateMiddleware stores the *v1pkg.ServiceAccount that
	// authenticated via HTTP Basic auth.
	ctxServiceAccountKey = "service-account"

	// ctxVMAccessTokenClaimsKey is the gin context key under which
	// authenticateMiddleware stores the claims returned by vmtempauth.Verify
	// for a valid Bearer token (read back as *vmtempauth.Claims).
	ctxVMAccessTokenClaimsKey = "vm-access-token-claims"
)

// ErrUnauthorized is returned by fetchServiceAccount when the named service
// account does not exist or the presented token does not match its token.
var ErrUnauthorized = errors.New("unauthorized")

// initAPI builds the controller's HTTP API.
//
// The engine includes:
//   - request logging and panic recovery via ginzap;
//   - the ginmetrics monitor (metric path "/metrics");
//   - the v1 endpoints.
//
// Everything is mounted under controller.apiPrefix, or "/" when no prefix is
// configured. Every v1 route runs behind authenticateMiddleware. Authorization
// is presumably performed inside each handler via authorize/authorizeAny.
func (controller *Controller) initAPI() *gin.Engine {
	ginEngine := gin.New()

	var group *gin.RouterGroup

	if controller.apiPrefix != "" {
		group = ginEngine.Group(controller.apiPrefix)
	} else {
		group = ginEngine.Group("/")
	}

	group.Use(
		ginzap.Ginzap(controller.logger.Desugar(), "", true),
		ginzap.RecoveryWithZap(controller.logger.Desugar(), true),
	)

	// expose metrics
	monitor := ginmetrics.GetMonitor()
	monitor.SetMetricPath("/metrics")
	monitor.Use(group)

	// v1 API
	v1 := group.Group("/v1")

	// Auth
	v1.Use(controller.authenticateMiddleware)

	// OpenAPI docs/spec (if enabled) and a way for the clients
	// to check that the API is working.
	//
	// The Swagger UI handler only depends on settings that are already fixed
	// here (the openapi.yaml route below is registered based on the same flag),
	// so construct it once instead of on every request.
	var swaggerUI http.Handler

	if controller.enableSwaggerDocs {
		apiURL := (&url.URL{Path: "/"}).JoinPath(controller.apiPrefix, "v1")

		swaggerUI = middleware.SwaggerUI(middleware.SwaggerUIOpts{
			Path:    apiURL.Path,
			SpecURL: apiURL.JoinPath("openapi.yaml").Path,
		}, nil)
	}

	v1.GET("/", func(c *gin.Context) {
		if swaggerUI != nil {
			swaggerUI.ServeHTTP(c.Writer, c.Request)

			return
		}

		c.Status(http.StatusOK)
	})

	if controller.enableSwaggerDocs {
		v1.GET("/openapi.yaml", func(c *gin.Context) {
			c.Data(http.StatusOK, "text/yaml", api.Spec)
		})
	}

	// Controller information
	v1.GET("/controller/info", func(c *gin.Context) {
		controller.controllerInfo(c).Respond(c)
	})

	// Cluster settings
	v1.GET("/cluster-settings", func(c *gin.Context) {
		controller.getClusterSettings(c).Respond(c)
	})
	v1.PUT("/cluster-settings", func(c *gin.Context) {
		controller.updateClusterSettings(c).Respond(c)
	})

	// Service accounts
	v1.POST("/service-accounts", func(c *gin.Context) {
		controller.createServiceAccount(c).Respond(c)
	})
	v1.PUT("/service-accounts/:name", func(c *gin.Context) {
		controller.updateServiceAccount(c).Respond(c)
	})
	v1.GET("/service-accounts/:name", func(c *gin.Context) {
		controller.getServiceAccount(c).Respond(c)
	})
	v1.GET("/service-accounts", func(c *gin.Context) {
		controller.listServiceAccounts(c).Respond(c)
	})
	v1.DELETE("/service-accounts/:name", func(c *gin.Context) {
		controller.deleteServiceAccount(c).Respond(c)
	})

	// Workers
	v1.POST("/workers", func(c *gin.Context) {
		controller.createWorker(c).Respond(c)
	})
	v1.PUT("/workers/:name", func(c *gin.Context) {
		controller.updateWorker(c).Respond(c)
	})
	v1.GET("/workers/:name", func(c *gin.Context) {
		controller.getWorker(c).Respond(c)
	})
	v1.GET("/workers", func(c *gin.Context) {
		controller.listWorkers(c).Respond(c)
	})
	v1.GET("/workers/:name/port-forward", func(c *gin.Context) {
		controller.portForwardWorker(c).Respond(c)
	})
	v1.DELETE("/workers/:name", func(c *gin.Context) {
		controller.deleteWorker(c).Respond(c)
	})

	// RPC v2
	v1.GET("/rpc/watch", func(c *gin.Context) {
		controller.rpcWatch(c).Respond(c)
	})
	v1.GET("/rpc/port-forward", func(c *gin.Context) {
		controller.rpcPortForward(c).Respond(c)
	})
	v1.POST("/rpc/resolve-ip", func(c *gin.Context) {
		controller.rpcResolveIP(c).Respond(c)
	})

	// VMs
	v1.POST("/vms", func(c *gin.Context) {
		controller.createVM(c).Respond(c)
	})
	v1.PUT("/vms/:name", func(c *gin.Context) {
		if strings.HasPrefix(c.GetHeader("User-Agent"), "Orchard/0") {
			// Backward compatibility for older Orchard Workers that still
			// use the PUT /vms/{name} API endpoint to update a VM status
			//
			// Note that we include the "0" here to avoid targeting users
			// of the github.com/cirruslabs/orchard/pkg/client package. For
			// them, the UA string should normally be "Orchard/unknown-unknown".
			//
			// After some months/years we can remove this workaround and at
			// the very worst the workers simply won't progress with the VMs
			// assigned to them. An upgrade to a newer version will fix that.
			controller.updateVMState(c).Respond(c)
		} else {
			controller.updateVMSpec(c).Respond(c)
		}
	})
	v1.PUT("/vms/:name/state", func(c *gin.Context) {
		controller.updateVMState(c).Respond(c)
	})
	v1.GET("/vms/:name", func(c *gin.Context) {
		controller.getVM(c).Respond(c)
	})
	v1.GET("/vms", func(c *gin.Context) {
		controller.listVMs(c).Respond(c)
	})
	v1.GET("/vms/:name/port-forward", func(c *gin.Context) {
		controller.portForwardVM(c).Respond(c)
	})
	v1.GET("/vms/:name/ip", func(c *gin.Context) {
		controller.ip(c).Respond(c)
	})
	v1.DELETE("/vms/:name", func(c *gin.Context) {
		controller.deleteVM(c).Respond(c)
	})
	v1.GET("/vms/:name/events", func(c *gin.Context) {
		controller.listVMEvents(c).Respond(c)
	})
	v1.POST("/vms/:name/events", func(c *gin.Context) {
		controller.appendVMEvents(c).Respond(c)
	})
	v1.POST("/vms/:name/access-tokens", func(c *gin.Context) {
		controller.issueVMAccessToken(c).Respond(c)
	})

	return ginEngine
}

// fetchServiceAccount looks up the service account called name in a store
// View transaction and checks that token matches its stored token.
//
// It returns ErrUnauthorized when the account is not found
// (storepkg.ErrNotFound) or the tokens differ. Any other store error is
// returned as-is. Tokens are compared with subtle.ConstantTimeCompare; that
// function returns early when lengths differ, so only the token contents, not
// their length, are protected against timing analysis.
func (controller *Controller) fetchServiceAccount(name string, token string) (*v1pkg.ServiceAccount, error) {
	var serviceAccount *v1pkg.ServiceAccount

	err := controller.store.View(func(txn storepkg.Transaction) error {
		var err error

		serviceAccount, err = txn.GetServiceAccount(name)

		return err
	})
	if err != nil {
		if errors.Is(err, storepkg.ErrNotFound) {
			return nil, ErrUnauthorized
		}

		return nil, err
	}

	if subtle.ConstantTimeCompare([]byte(serviceAccount.Token), []byte(token)) == 0 {
		return nil, ErrUnauthorized
	}

	return serviceAccount, nil
}

// authenticateMiddleware identifies the caller of a v1 API request. It accepts
// two credential schemes:
//
//   - "Authorization: Bearer <token>" (scheme name matched
//     case-insensitively). The token is verified with vmtempauth.Verify
//     against controller.vmAccessTokenKey at the current UTC time, and the
//     resulting claims are stored under ctxVMAccessTokenClaimsKey.
//   - HTTP Basic auth. The user name is the service account name and the
//     password is its token. On success the service account is stored under
//     ctxServiceAccountKey.
//
// Requests that present neither scheme proceed unauthenticated. Invalid
// credentials get a 401. A store failure during the service account lookup is
// reported via responder.Error. In both failure cases the handler chain is
// aborted, so the route handler never runs.
func (controller *Controller) authenticateMiddleware(c *gin.Context) {
	const bearerPrefix = "Bearer "

	authHeader := strings.TrimSpace(c.GetHeader("Authorization"))

	if len(authHeader) >= len(bearerPrefix) &&
		strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) {
		token := strings.TrimSpace(authHeader[len(bearerPrefix):])
		if token == "" {
			responder.Code(http.StatusUnauthorized).Respond(c)
			c.Abort()

			return
		}

		claims, err := vmtempauth.Verify(controller.vmAccessTokenKey, token, time.Now().UTC())
		if err != nil {
			responder.Code(http.StatusUnauthorized).Respond(c)
			c.Abort()

			return
		}

		c.Set(ctxVMAccessTokenClaimsKey, claims)
		c.Next()

		return
	}

	// Retrieve presented credentials (if any)
	user, password, ok := c.Request.BasicAuth()
	if !ok {
		c.Next()

		return
	}

	serviceAccount, err := controller.fetchServiceAccount(user, password)
	if err != nil {
		if errors.Is(err, ErrUnauthorized) {
			responder.Code(http.StatusUnauthorized).Respond(c)
		} else {
			responder.Error(err).Respond(c)
		}

		// Gin keeps running the remaining handlers (including the route
		// handler itself) when a middleware returns without aborting.
		c.Abort()

		return
	}

	// Remember service account for further authorize() calls
	c.Set(ctxServiceAccountKey, serviceAccount)
	c.Next()
}

// serviceAccountFromContext returns the service account that
// authenticateMiddleware stored in ctx. The boolean is false when none was
// stored or the stored value has an unexpected type.
func (controller *Controller) serviceAccountFromContext(ctx *gin.Context) (*v1pkg.ServiceAccount, bool) {
	// A missing key yields a nil interface, for which the checked type
	// assertion reports false.
	value, _ := ctx.Get(ctxServiceAccountKey)
	serviceAccount, ok := value.(*v1pkg.ServiceAccount)

	return serviceAccount, ok
}

// vmAccessTokenClaimsFromContext returns the VM access token claims that
// authenticateMiddleware stored in ctx. The boolean is false when none were
// stored or the stored value has an unexpected type.
func (controller *Controller) vmAccessTokenClaimsFromContext(ctx *gin.Context) (*vmtempauth.Claims, bool) {
	value, _ := ctx.Get(ctxVMAccessTokenClaimsKey)
	claims, ok := value.(*vmtempauth.Claims)

	return claims, ok
}

// AuthorizeMode selects how authorizeBase matches a service account's roles
// against the required roles.
type AuthorizeMode int

const (
	// AuthorizeModeAll requires the service account to hold every required role.
	AuthorizeModeAll AuthorizeMode = iota
	// AuthorizeModeAny requires the service account to hold at least one of
	// the required roles.
	AuthorizeModeAny
)

// authorize checks that the authenticated service account holds all of
// requiredRoles. It returns nil when the request may proceed, or a responder
// describing the rejection otherwise.
func (controller *Controller) authorize(
	ctx *gin.Context,
	requiredRoles ...v1pkg.ServiceAccountRole,
) responder.Responder {
	return controller.authorizeBase(ctx, AuthorizeModeAll, requiredRoles...)
}

// authorizeAny checks that the authenticated service account holds at least
// one of requiredRoles. It returns nil when the request may proceed, or a
// responder describing the rejection otherwise.
func (controller *Controller) authorizeAny(
	ctx *gin.Context,
	requiredRoles ...v1pkg.ServiceAccountRole,
) responder.Responder {
	return controller.authorizeBase(ctx, AuthorizeModeAny, requiredRoles...)
}

// authorizeBase implements authorize and authorizeAny.
//
// It returns nil when:
//   - authentication is disabled (controller.insecureAuthDisabled), or
//   - the Basic-auth service account satisfies requiredRoles under mode.
//
// Otherwise it returns a 401 responder. When no service account is present the
// 401 is bare. When roles are insufficient it carries a JSON error naming the
// required roles.
//
// Only the service account from HTTP Basic auth is consulted. VM access-token
// claims are not considered here.
func (controller *Controller) authorizeBase(
	ctx *gin.Context,
	mode AuthorizeMode,
	requiredRoles ...v1pkg.ServiceAccountRole,
) responder.Responder {
	if controller.insecureAuthDisabled {
		return nil
	}

	serviceAccount, ok := controller.serviceAccountFromContext(ctx)
	if !ok {
		return responder.Code(http.StatusUnauthorized)
	}

	serviceAccountRolesSet := mapset.NewSet[v1pkg.ServiceAccountRole](serviceAccount.Roles...)

	// An unknown mode leaves authorized == false and hint == "".
	var authorized bool
	var hint string

	switch mode {
	case AuthorizeModeAll:
		authorized = serviceAccountRolesSet.Contains(requiredRoles...)
		hint = "all of the following roles must be present"
	case AuthorizeModeAny:
		authorized = serviceAccountRolesSet.ContainsAny(requiredRoles...)
		hint = "any of the following roles must be present"
	}

	if authorized {
		return nil
	}

	humanizedRoles := lo.Map(requiredRoles, func(role v1pkg.ServiceAccountRole, _ int) string {
		return string(role)
	})

	return responder.JSON(http.StatusUnauthorized,
		NewErrorResponse("%s: %s", hint, strings.Join(humanizedRoles, ", ")))
}

// authorizeGRPC authenticates a gRPC call and checks that the caller holds
// every role in scopes.
//
// Credentials come from the incoming metadata keys
// rpc.MetadataServiceAccountNameKey and rpc.MetadataServiceAccountTokenKey.
// Each key must carry exactly one value. It always returns true when
// authentication is disabled (controller.insecureAuthDisabled).
func (controller *Controller) authorizeGRPC(ctx context.Context, scopes ...v1pkg.ServiceAccountRole) bool {
	if controller.insecureAuthDisabled {
		return true
	}

	name := metadata.ValueFromIncomingContext(ctx, rpc.MetadataServiceAccountNameKey)
	if len(name) != 1 {
		return false
	}

	token := metadata.ValueFromIncomingContext(ctx, rpc.MetadataServiceAccountTokenKey)
	if len(token) != 1 {
		return false
	}

	serviceAccount, err := controller.fetchServiceAccount(name[0], token[0])
	if err != nil {
		return false
	}

	return mapset.NewSet[v1pkg.ServiceAccountRole](serviceAccount.Roles...).Contains(scopes...)
}

// storeTransactionFunc is the shape shared by the store's View and Update
// methods. Each runs operation inside a transaction and returns the store's
// error.
type storeTransactionFunc func(operation func(txn storepkg.Transaction) error) error

// storeView runs view inside a store View transaction and returns its
// responder. A store-level failure is logged and turned into a 500.
func (controller *Controller) storeView(view func(txn storepkg.Transaction) responder.Responder) responder.Responder {
	return adaptResponderToStoreOperation(controller.logger, controller.store.View, view)
}

// storeUpdate runs update inside a store Update transaction and returns its
// responder. A store-level failure is logged and turned into a 500.
func (controller *Controller) storeUpdate(
	update func(txn storepkg.Transaction) responder.Responder,
) responder.Responder {
	return adaptResponderToStoreOperation(controller.logger, controller.store.Update, update)
}

// adaptResponderToStoreOperation bridges a responder-returning callback to a
// store operation that expects an error-returning callback.
//
// The callback handed to the store always returns nil. The store operation is
// therefore never asked to fail because of the responder's outcome. Presumably
// this means an Update commits whatever the callback wrote even when the
// responder reports an error; NOTE(review): confirm against the store
// implementation. If the store operation itself fails, the error is logged and
// a 500 responder is returned instead of the callback's result.
func adaptResponderToStoreOperation(
	logger *zap.SugaredLogger,
	storeOperation storeTransactionFunc,
	responderOperation func(txn storepkg.Transaction) responder.Responder,
) responder.Responder {
	var result responder.Responder

	if err := storeOperation(func(txn storepkg.Transaction) error {
		result = responderOperation(txn)

		return nil
	}); err != nil {
		logger.Errorf("encountered an error during store operation: %v", err)

		return responder.Code(http.StatusInternalServerError)
	}

	return result
}