whisper.cpp/src/parakeet-arch.h

189 lines
10 KiB
C++

#pragma once
#include "ggml.h"
#include <map>
// Identifiers for the Parakeet model tensors referenced by the lookup tables
// in this header.
//
// No explicit values are assigned, so enumerators are numbered from 0 in
// declaration order; inserting or reordering entries changes those numbers.
enum parakeet_tensor {
    // ---- Encoder: pre_encode ----
    PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT,
    PARAKEET_TENSOR_ENC_PRE_OUT_BIAS,
    PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT,
    PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS,
    PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT,
    PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS,
    PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT,
    PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS,
    PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT,
    PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS,
    PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT,
    PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS,

    // ---- Encoder: per-layer tensors ----
    // Feed-forward block 1
    PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT,
    PARAKEET_TENSOR_ENC_NORM_FF1_BIAS,
    PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT,
    PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT,
    // Convolution block (including batch-norm statistics)
    PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT,
    PARAKEET_TENSOR_ENC_NORM_CONV_BIAS,
    PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT,
    PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT,
    PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT,
    PARAKEET_TENSOR_ENC_CONV_BN_BIAS,
    PARAKEET_TENSOR_ENC_CONV_BN_MEAN,
    PARAKEET_TENSOR_ENC_CONV_BN_VAR,
    PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES,
    PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT,
    // Self-attention block
    PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT,
    PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS,
    PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U,
    PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V,
    PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT,
    PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT,
    PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT,
    PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT,
    PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT,
    // Feed-forward block 2
    PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT,
    PARAKEET_TENSOR_ENC_NORM_FF2_BIAS,
    PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT,
    PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT,
    // Output norm of the layer
    PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT,
    PARAKEET_TENSOR_ENC_NORM_OUT_BIAS,

    // ---- Prediction network ----
    PARAKEET_TENSOR_PRED_EMBED_WEIGHT,
    PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH,
    PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH,
    PARAKEET_TENSOR_PRED_LSTM_BIAS_H,

    // ---- Joint network ----
    PARAKEET_TENSOR_JOINT_PRED_WEIGHT,
    PARAKEET_TENSOR_JOINT_PRED_BIAS,
    PARAKEET_TENSOR_JOINT_ENC_WEIGHT,
    PARAKEET_TENSOR_JOINT_ENC_BIAS,
    PARAKEET_TENSOR_JOINT_NET_WEIGHT,
    PARAKEET_TENSOR_JOINT_NET_BIAS,
};
// Lookup table from a parakeet_tensor id to its tensor name string.
//
// Encoder-layer names carry a "%d" placeholder for the layer number, so they
// are format strings rather than final names. The LSTM names also end in
// "%d" (presumably the LSTM layer index; confirm against the loader).
// All remaining names are complete as written.
static const std::map<parakeet_tensor, const char *> PARAKEET_TENSOR_NAMES = {
    // ---- Encoder: pre_encode ----
    { PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, "encoder.pre_encode.out.weight" },
    { PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, "encoder.pre_encode.out.bias" },
    { PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, "encoder.pre_encode.conv.0.weight" },
    { PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, "encoder.pre_encode.conv.0.bias" },
    { PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, "encoder.pre_encode.conv.2.weight" },
    { PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, "encoder.pre_encode.conv.2.bias" },
    { PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, "encoder.pre_encode.conv.3.weight" },
    { PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, "encoder.pre_encode.conv.3.bias" },
    { PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, "encoder.pre_encode.conv.5.weight" },
    { PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, "encoder.pre_encode.conv.5.bias" },
    { PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, "encoder.pre_encode.conv.6.weight" },
    { PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, "encoder.pre_encode.conv.6.bias" },

    // ---- Encoder: per-layer tensors (%d is the layer number) ----
    // Feed-forward block 1
    { PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, "encoder.layers.%d.norm_feed_forward1.weight" },
    { PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, "encoder.layers.%d.norm_feed_forward1.bias" },
    { PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, "encoder.layers.%d.feed_forward1.linear1.weight" },
    { PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, "encoder.layers.%d.feed_forward1.linear2.weight" },
    // Convolution block
    { PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, "encoder.layers.%d.norm_conv.weight" },
    { PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, "encoder.layers.%d.norm_conv.bias" },
    { PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, "encoder.layers.%d.conv.pointwise_conv1.weight" },
    { PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, "encoder.layers.%d.conv.depthwise_conv.weight" },
    { PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, "encoder.layers.%d.conv.batch_norm.weight" },
    { PARAKEET_TENSOR_ENC_CONV_BN_BIAS, "encoder.layers.%d.conv.batch_norm.bias" },
    { PARAKEET_TENSOR_ENC_CONV_BN_MEAN, "encoder.layers.%d.conv.batch_norm.running_mean" },
    { PARAKEET_TENSOR_ENC_CONV_BN_VAR, "encoder.layers.%d.conv.batch_norm.running_var" },
    { PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, "encoder.layers.%d.conv.batch_norm.num_batches_tracked" },
    { PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, "encoder.layers.%d.conv.pointwise_conv2.weight" },
    // Self-attention block
    { PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, "encoder.layers.%d.norm_self_att.weight" },
    { PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, "encoder.layers.%d.norm_self_att.bias" },
    { PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, "encoder.layers.%d.self_attn.pos_bias_u" },
    { PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, "encoder.layers.%d.self_attn.pos_bias_v" },
    { PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, "encoder.layers.%d.self_attn.linear_q.weight" },
    { PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, "encoder.layers.%d.self_attn.linear_k.weight" },
    { PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, "encoder.layers.%d.self_attn.linear_v.weight" },
    { PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, "encoder.layers.%d.self_attn.linear_out.weight" },
    { PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, "encoder.layers.%d.self_attn.linear_pos.weight" },
    // Feed-forward block 2
    { PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, "encoder.layers.%d.norm_feed_forward2.weight" },
    { PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, "encoder.layers.%d.norm_feed_forward2.bias" },
    { PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, "encoder.layers.%d.feed_forward2.linear1.weight" },
    { PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, "encoder.layers.%d.feed_forward2.linear2.weight" },
    // Output norm of the layer
    { PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, "encoder.layers.%d.norm_out.weight" },
    { PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, "encoder.layers.%d.norm_out.bias" },

    // ---- Prediction network ----
    { PARAKEET_TENSOR_PRED_EMBED_WEIGHT, "decoder.prediction.embed.weight" },
    { PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, "decoder.prediction.dec_rnn.lstm.weight_ih_l%d" },
    { PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, "decoder.prediction.dec_rnn.lstm.weight_hh_l%d" },
    { PARAKEET_TENSOR_PRED_LSTM_BIAS_H, "decoder.prediction.dec_rnn.lstm.bias_h_l%d" },

    // ---- Joint network ----
    { PARAKEET_TENSOR_JOINT_PRED_WEIGHT, "joint.pred.weight" },
    { PARAKEET_TENSOR_JOINT_PRED_BIAS, "joint.pred.bias" },
    { PARAKEET_TENSOR_JOINT_ENC_WEIGHT, "joint.enc.weight" },
    { PARAKEET_TENSOR_JOINT_ENC_BIAS, "joint.enc.bias" },
    { PARAKEET_TENSOR_JOINT_NET_WEIGHT, "joint.joint_net.2.weight" },
    { PARAKEET_TENSOR_JOINT_NET_BIAS, "joint.joint_net.2.bias" },
};
// Lookup table from a parakeet_tensor id to a ggml_op value.
//
// NOTE(review): presumably the op is the ggml operation that consumes the
// tensor when the graph is built; confirm against the code that reads this
// table. The batch-norm num_batches_tracked entry maps to GGML_OP_NONE.
static const std::map<parakeet_tensor, ggml_op> PARAKEET_TENSOR_INFO = {
    // ---- Encoder: pre_encode ----
    { PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, GGML_OP_MUL_MAT },
    { PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, GGML_OP_ADD },
    { PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, GGML_OP_IM2COL },
    { PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, GGML_OP_ADD },
    { PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, GGML_OP_IM2COL },
    { PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, GGML_OP_ADD },
    { PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, GGML_OP_IM2COL },
    { PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, GGML_OP_ADD },
    { PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, GGML_OP_IM2COL },
    { PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, GGML_OP_ADD },
    { PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, GGML_OP_IM2COL },
    { PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, GGML_OP_ADD },

    // ---- Encoder: per-layer tensors ----
    // Feed-forward block 1
    { PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, GGML_OP_MUL },
    { PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, GGML_OP_ADD },
    { PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, GGML_OP_MUL_MAT },
    { PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, GGML_OP_MUL_MAT },
    // Convolution block
    { PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, GGML_OP_MUL },
    { PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, GGML_OP_ADD },
    { PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, GGML_OP_IM2COL },
    { PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, GGML_OP_IM2COL },
    { PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, GGML_OP_MUL },
    { PARAKEET_TENSOR_ENC_CONV_BN_BIAS, GGML_OP_ADD },
    { PARAKEET_TENSOR_ENC_CONV_BN_MEAN, GGML_OP_SUB },
    { PARAKEET_TENSOR_ENC_CONV_BN_VAR, GGML_OP_DIV },
    { PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, GGML_OP_NONE },
    { PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, GGML_OP_IM2COL },
    // Self-attention block
    { PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, GGML_OP_MUL },
    { PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, GGML_OP_ADD },
    { PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, GGML_OP_ADD },
    { PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, GGML_OP_ADD },
    { PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, GGML_OP_MUL_MAT },
    { PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, GGML_OP_MUL_MAT },
    { PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, GGML_OP_MUL_MAT },
    { PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, GGML_OP_MUL_MAT },
    { PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, GGML_OP_MUL_MAT },
    // Feed-forward block 2
    { PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, GGML_OP_MUL },
    { PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, GGML_OP_ADD },
    { PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, GGML_OP_MUL_MAT },
    { PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, GGML_OP_MUL_MAT },
    // Output norm of the layer
    { PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, GGML_OP_MUL },
    { PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, GGML_OP_ADD },

    // ---- Prediction network ----
    { PARAKEET_TENSOR_PRED_EMBED_WEIGHT, GGML_OP_GET_ROWS },
    { PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, GGML_OP_MUL_MAT },
    { PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, GGML_OP_MUL_MAT },
    { PARAKEET_TENSOR_PRED_LSTM_BIAS_H, GGML_OP_ADD },

    // ---- Joint network ----
    { PARAKEET_TENSOR_JOINT_PRED_WEIGHT, GGML_OP_MUL_MAT },
    { PARAKEET_TENSOR_JOINT_PRED_BIAS, GGML_OP_ADD },
    { PARAKEET_TENSOR_JOINT_ENC_WEIGHT, GGML_OP_MUL_MAT },
    { PARAKEET_TENSOR_JOINT_ENC_BIAS, GGML_OP_ADD },
    { PARAKEET_TENSOR_JOINT_NET_WEIGHT, GGML_OP_MUL_MAT },
    { PARAKEET_TENSOR_JOINT_NET_BIAS, GGML_OP_ADD },
};