ggml-webgpu: Check earlier for WebGPU required features (llama/23879)

This commit is contained in:
Reese Levine 2026-05-29 14:16:05 -07:00 committed by Georgi Gerganov
parent acd91d2c38
commit 9147a9676b
1 changed file with 12 additions and 9 deletions

View File

@@ -3724,7 +3724,7 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {
ctx->memset_pipeline = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants);
}
-static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
+static void create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
wgpu::RequestAdapterOptions options = {};
#ifndef __EMSCRIPTEN__
@@ -3762,10 +3762,6 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
ctx->webgpu_global_ctx->command_submit_batch_size = ggml_backend_webgpu_get_command_submit_batch_size();
ctx->webgpu_global_ctx->max_inflight_batches = ggml_backend_webgpu_get_max_inflight_batches();
ctx->webgpu_global_ctx->vendor = info.vendor;
-wgpu::SupportedFeatures features;
-ctx->webgpu_global_ctx->adapter.GetFeatures(&features);
-// we require f16 support
-GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
ctx->webgpu_global_ctx->capabilities.supports_subgroups =
ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::Subgroups);
// for dot4I8packed
@@ -3877,7 +3873,6 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
"device_desc: %s\n",
info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID,
std::string(info.device).c_str(), std::string(info.description).c_str());
-return true;
}
static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
@@ -4507,7 +4502,12 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
UINT64_MAX);
}
-if (adapter != nullptr) {
+// WebGPU backend requires f16 support and, on native, implicit device synchronization.
+if (adapter != nullptr && adapter.HasFeature(wgpu::FeatureName::ShaderF16)
+#ifndef __EMSCRIPTEN__
+&& adapter.HasFeature(wgpu::FeatureName::ImplicitDeviceSynchronization)
+#endif
+) {
ctx->device_count = 1;
}
@@ -4515,8 +4515,11 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
}
ggml_backend_t ggml_backend_webgpu_init(void) {
-ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0);
+ggml_backend_reg_t reg = ggml_backend_webgpu_reg();
+if (ggml_backend_reg_dev_count(reg) == 0) {
+return nullptr;
+}
+ggml_backend_dev_t dev = ggml_backend_reg_dev_get(reg, 0);
return ggml_backend_webgpu_backend_init(dev, nullptr);
}