metal : make the backend async (llama/15906)

This commit is contained in:
Georgi Gerganov 2025-09-20 13:44:27 +03:00
parent e2c7f1cccd
commit 7eae055e61
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 674 additions and 339 deletions

View File

@ -43,14 +43,8 @@ GGML_BACKEND_API ggml_backend_t ggml_backend_metal_init(void);
GGML_BACKEND_API bool ggml_backend_is_metal(ggml_backend_t backend);
GGML_DEPRECATED(
GGML_BACKEND_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size),
"obsoleted by the new device interface - https://github.com/ggml-org/llama.cpp/pull/9713");
GGML_BACKEND_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data);
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
// helper to check if the device supports a specific family
// ideally, the user code should be doing these checks
// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf

File diff suppressed because it is too large Load Diff

View File

@ -5571,38 +5571,6 @@ kernel void kernel_flash_attn_ext_vec_reduce(
#undef DV
}
template<typename T>
kernel void kernel_set(
constant ggml_metal_kargs_set & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i13 = tgpig[2];
const int i12 = tgpig[1];
const int i11 = tgpig[0];
const int64_t n = i13*args.ne12*args.ne11*args.ne10 + i12*args.ne11*args.ne10 + i11*args.ne10;
const int64_t i3 = n / (args.ne12*args.ne11*args.ne10);
const int64_t i2 = (n - i3*args.ne12*args.ne11*args.ne10) / (args.ne11*args.ne10);
const int64_t i1 = (n - i3*args.ne12*args.ne11*args.ne10 - i2*args.ne11*args.ne10) / args.ne10;
device T * dst_data = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + args.offs);
for (int64_t i10 = tpitg.x; i10 < args.ne10; i10 += ntg.x) {
device const T * src = (device T *) (src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10);
dst_data[i10] = (T) src[0];
}
}
typedef decltype(kernel_set<float>) kernel_set_t;
template [[host_name("kernel_set_f32")]] kernel kernel_set_t kernel_set<float>;
template [[host_name("kernel_set_i32")]] kernel kernel_set_t kernel_set<int32_t>;
template<typename T0, typename T1>
kernel void kernel_cpy(
constant ggml_metal_kargs_cpy & args,