model : add ASR support for LFM2-Audio-1.5B (conformer) (llama/18106)
* ASR with LFM2-Audio-1.5B * Set rope_theta * Fix comment * Remove rope_theta setting * Address PR feedback * rename functions to conformer * remove some redundant ggml_cont * fix missing tensor * add prefix "a." for conv tensors * remove redundant reshape * clean up * add test model --------- Co-authored-by: Tarek Dakhran <tarek@liquid.ai>
This commit is contained in:
parent
956fac433b
commit
fea481f412
|
|
@ -102,31 +102,25 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
|
|||
const int threads = 128;
|
||||
GGML_ASSERT(nr % threads == 0);
|
||||
|
||||
if (n_t <= 32) {
|
||||
const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
|
||||
if (nc == 4) {
|
||||
ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
|
||||
dst, dst_nb0, dst_nb1, dst_nb2, n_t);
|
||||
} else if (nc == 3) {
|
||||
ssm_conv_f32<threads, 3><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
|
||||
dst, dst_nb0, dst_nb1, dst_nb2, n_t);
|
||||
auto launch_kernel = [&](auto NC) {
|
||||
constexpr int kNC = decltype(NC)::value;
|
||||
if (n_t <= 32) {
|
||||
const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
|
||||
ssm_conv_f32<threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
|
||||
dst, dst_nb0, dst_nb1, dst_nb2, n_t);
|
||||
} else {
|
||||
GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
|
||||
}
|
||||
} else {
|
||||
if (nc == 4) {
|
||||
const int64_t split_n_t = 32;
|
||||
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
|
||||
ssm_conv_long_token_f32<threads, 4, split_n_t><<<blocks, threads, 0, stream>>>(
|
||||
ssm_conv_long_token_f32<threads, kNC, split_n_t><<<blocks, threads, 0, stream>>>(
|
||||
src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
|
||||
} else if (nc == 3) {
|
||||
const int64_t split_n_t = 32;
|
||||
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
|
||||
ssm_conv_long_token_f32<threads, 3, split_n_t><<<blocks, threads, 0, stream>>>(
|
||||
src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
|
||||
} else {
|
||||
GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
|
||||
}
|
||||
};
|
||||
|
||||
switch (nc) {
|
||||
case 3: launch_kernel(std::integral_constant<int, 3>{}); break;
|
||||
case 4: launch_kernel(std::integral_constant<int, 4>{}); break;
|
||||
case 9: launch_kernel(std::integral_constant<int, 9>{}); break;
|
||||
default: GGML_ABORT("Only support kernel sizes 3, 4, 9 right now.");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue