ggml-webgpu: Check earlier for WebGPU required features (llama/23879)
This commit is contained in:
parent
acd91d2c38
commit
9147a9676b
|
|
@ -3724,7 +3724,7 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {
|
|||
ctx->memset_pipeline = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants);
|
||||
}
|
||||
|
||||
static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
||||
static void create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
||||
wgpu::RequestAdapterOptions options = {};
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
|
|
@ -3762,10 +3762,6 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
|||
ctx->webgpu_global_ctx->command_submit_batch_size = ggml_backend_webgpu_get_command_submit_batch_size();
|
||||
ctx->webgpu_global_ctx->max_inflight_batches = ggml_backend_webgpu_get_max_inflight_batches();
|
||||
ctx->webgpu_global_ctx->vendor = info.vendor;
|
||||
wgpu::SupportedFeatures features;
|
||||
ctx->webgpu_global_ctx->adapter.GetFeatures(&features);
|
||||
// we require f16 support
|
||||
GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
|
||||
ctx->webgpu_global_ctx->capabilities.supports_subgroups =
|
||||
ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::Subgroups);
|
||||
// for dot4I8packed
|
||||
|
|
@ -3877,7 +3873,6 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
|||
"device_desc: %s\n",
|
||||
info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID,
|
||||
std::string(info.device).c_str(), std::string(info.description).c_str());
|
||||
return true;
|
||||
}
|
||||
|
||||
static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
|
||||
|
|
@ -4507,7 +4502,12 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
|
|||
UINT64_MAX);
|
||||
}
|
||||
|
||||
if (adapter != nullptr) {
|
||||
// WebGPU backend requires f16 support and, on native, implicit device synchronization.
|
||||
if (adapter != nullptr && adapter.HasFeature(wgpu::FeatureName::ShaderF16)
|
||||
#ifndef __EMSCRIPTEN__
|
||||
&& adapter.HasFeature(wgpu::FeatureName::ImplicitDeviceSynchronization)
|
||||
#endif
|
||||
) {
|
||||
ctx->device_count = 1;
|
||||
}
|
||||
|
||||
|
|
@ -4515,8 +4515,11 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
|
|||
}
|
||||
|
||||
ggml_backend_t ggml_backend_webgpu_init(void) {
|
||||
ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0);
|
||||
|
||||
ggml_backend_reg_t reg = ggml_backend_webgpu_reg();
|
||||
if (ggml_backend_reg_dev_count(reg) == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
ggml_backend_dev_t dev = ggml_backend_reg_dev_get(reg, 0);
|
||||
return ggml_backend_webgpu_backend_init(dev, nullptr);
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue