// Parakeet model: tensor identifiers, tensor name patterns, and per-tensor ggml ops.
#pragma once
|
|
|
|
#include "ggml.h"
|
|
|
|
#include <map>
|
|
|
|
// Identifiers for the tensors of a Parakeet model.
//
// Enumerators are implicitly numbered from 0 in declaration order and are used
// as keys of the PARAKEET_TENSOR_NAMES and PARAKEET_TENSOR_INFO tables, so new
// entries should be added to all three places together.
enum parakeet_tensor {
    // Encoder pre_encode
    PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT,
    PARAKEET_TENSOR_ENC_PRE_OUT_BIAS,
    PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT,
    PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS,
    PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT,
    PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS,
    PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT,
    PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS,
    PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT,
    PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS,
    PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT,
    PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS,

    // Encoder layers (per-layer)
    PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT,
    PARAKEET_TENSOR_ENC_NORM_FF1_BIAS,
    PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT,
    PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT,
    PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT,
    PARAKEET_TENSOR_ENC_NORM_CONV_BIAS,
    PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT,
    PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT,
    PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT,
    PARAKEET_TENSOR_ENC_CONV_BN_BIAS,
    PARAKEET_TENSOR_ENC_CONV_BN_MEAN,
    PARAKEET_TENSOR_ENC_CONV_BN_VAR,
    PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES,
    PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT,
    PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT,
    PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS,
    PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U,
    PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V,
    PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT,
    PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT,
    PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT,
    PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT,
    PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT,
    PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT,
    PARAKEET_TENSOR_ENC_NORM_FF2_BIAS,
    PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT,
    PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT,
    PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT,
    PARAKEET_TENSOR_ENC_NORM_OUT_BIAS,

    // Prediction network
    PARAKEET_TENSOR_PRED_EMBED_WEIGHT,
    PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH,
    PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH,
    PARAKEET_TENSOR_PRED_LSTM_BIAS_H,

    // Joint network
    PARAKEET_TENSOR_JOINT_PRED_WEIGHT,
    PARAKEET_TENSOR_JOINT_PRED_BIAS,
    PARAKEET_TENSOR_JOINT_ENC_WEIGHT,
    PARAKEET_TENSOR_JOINT_ENC_BIAS,
    PARAKEET_TENSOR_JOINT_NET_WEIGHT,
    PARAKEET_TENSOR_JOINT_NET_BIAS,
};
|
|
|
|
static const std::map<parakeet_tensor, const char *> PARAKEET_TENSOR_NAMES = {
|
|
// Encoder pre_encode
|
|
{PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, "encoder.pre_encode.out.weight"},
|
|
{PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, "encoder.pre_encode.out.bias"},
|
|
{PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, "encoder.pre_encode.conv.0.weight"},
|
|
{PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, "encoder.pre_encode.conv.0.bias"},
|
|
{PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, "encoder.pre_encode.conv.2.weight"},
|
|
{PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, "encoder.pre_encode.conv.2.bias"},
|
|
{PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, "encoder.pre_encode.conv.3.weight"},
|
|
{PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, "encoder.pre_encode.conv.3.bias"},
|
|
{PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, "encoder.pre_encode.conv.5.weight"},
|
|
{PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, "encoder.pre_encode.conv.5.bias"},
|
|
{PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, "encoder.pre_encode.conv.6.weight"},
|
|
{PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, "encoder.pre_encode.conv.6.bias"},
|
|
|
|
// Encoder layers (use %d for layer number)
|
|
{PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, "encoder.layers.%d.norm_feed_forward1.weight"},
|
|
{PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, "encoder.layers.%d.norm_feed_forward1.bias"},
|
|
{PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, "encoder.layers.%d.feed_forward1.linear1.weight"},
|
|
{PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, "encoder.layers.%d.feed_forward1.linear2.weight"},
|
|
{PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, "encoder.layers.%d.norm_conv.weight"},
|
|
{PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, "encoder.layers.%d.norm_conv.bias"},
|
|
{PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, "encoder.layers.%d.conv.pointwise_conv1.weight"},
|
|
{PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, "encoder.layers.%d.conv.depthwise_conv.weight"},
|
|
{PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, "encoder.layers.%d.conv.batch_norm.weight"},
|
|
{PARAKEET_TENSOR_ENC_CONV_BN_BIAS, "encoder.layers.%d.conv.batch_norm.bias"},
|
|
{PARAKEET_TENSOR_ENC_CONV_BN_MEAN, "encoder.layers.%d.conv.batch_norm.running_mean"},
|
|
{PARAKEET_TENSOR_ENC_CONV_BN_VAR, "encoder.layers.%d.conv.batch_norm.running_var"},
|
|
{PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, "encoder.layers.%d.conv.batch_norm.num_batches_tracked"},
|
|
{PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, "encoder.layers.%d.conv.pointwise_conv2.weight"},
|
|
{PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, "encoder.layers.%d.norm_self_att.weight"},
|
|
{PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, "encoder.layers.%d.norm_self_att.bias"},
|
|
{PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, "encoder.layers.%d.self_attn.pos_bias_u"},
|
|
{PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, "encoder.layers.%d.self_attn.pos_bias_v"},
|
|
{PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, "encoder.layers.%d.self_attn.linear_q.weight"},
|
|
{PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, "encoder.layers.%d.self_attn.linear_k.weight"},
|
|
{PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, "encoder.layers.%d.self_attn.linear_v.weight"},
|
|
{PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, "encoder.layers.%d.self_attn.linear_out.weight"},
|
|
{PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, "encoder.layers.%d.self_attn.linear_pos.weight"},
|
|
{PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, "encoder.layers.%d.norm_feed_forward2.weight"},
|
|
{PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, "encoder.layers.%d.norm_feed_forward2.bias"},
|
|
{PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, "encoder.layers.%d.feed_forward2.linear1.weight"},
|
|
{PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, "encoder.layers.%d.feed_forward2.linear2.weight"},
|
|
{PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, "encoder.layers.%d.norm_out.weight"},
|
|
{PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, "encoder.layers.%d.norm_out.bias"},
|
|
|
|
// Prediction network
|
|
{PARAKEET_TENSOR_PRED_EMBED_WEIGHT, "decoder.prediction.embed.weight"},
|
|
{PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, "decoder.prediction.dec_rnn.lstm.weight_ih_l%d"},
|
|
{PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, "decoder.prediction.dec_rnn.lstm.weight_hh_l%d"},
|
|
{PARAKEET_TENSOR_PRED_LSTM_BIAS_H, "decoder.prediction.dec_rnn.lstm.bias_h_l%d"},
|
|
|
|
// Joint network
|
|
{PARAKEET_TENSOR_JOINT_PRED_WEIGHT, "joint.pred.weight"},
|
|
{PARAKEET_TENSOR_JOINT_PRED_BIAS, "joint.pred.bias"},
|
|
{PARAKEET_TENSOR_JOINT_ENC_WEIGHT, "joint.enc.weight"},
|
|
{PARAKEET_TENSOR_JOINT_ENC_BIAS, "joint.enc.bias"},
|
|
{PARAKEET_TENSOR_JOINT_NET_WEIGHT, "joint.joint_net.2.weight"},
|
|
{PARAKEET_TENSOR_JOINT_NET_BIAS, "joint.joint_net.2.bias"},
|
|
};
|
|
|
|
static const std::map<parakeet_tensor, ggml_op> PARAKEET_TENSOR_INFO = {
|
|
// Encoder pre_encode
|
|
{PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, GGML_OP_MUL_MAT},
|
|
{PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, GGML_OP_ADD},
|
|
{PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, GGML_OP_IM2COL},
|
|
{PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, GGML_OP_ADD},
|
|
{PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, GGML_OP_IM2COL},
|
|
{PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, GGML_OP_ADD},
|
|
{PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, GGML_OP_IM2COL},
|
|
{PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, GGML_OP_ADD},
|
|
{PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, GGML_OP_IM2COL},
|
|
{PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, GGML_OP_ADD},
|
|
{PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, GGML_OP_IM2COL},
|
|
{PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, GGML_OP_ADD},
|
|
|
|
// Encoder layers
|
|
{PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, GGML_OP_MUL},
|
|
{PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, GGML_OP_ADD},
|
|
{PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, GGML_OP_MUL_MAT},
|
|
{PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, GGML_OP_MUL_MAT},
|
|
{PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, GGML_OP_MUL},
|
|
{PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, GGML_OP_ADD},
|
|
{PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, GGML_OP_IM2COL},
|
|
{PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, GGML_OP_IM2COL},
|
|
{PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, GGML_OP_MUL},
|
|
{PARAKEET_TENSOR_ENC_CONV_BN_BIAS, GGML_OP_ADD},
|
|
{PARAKEET_TENSOR_ENC_CONV_BN_MEAN, GGML_OP_SUB},
|
|
{PARAKEET_TENSOR_ENC_CONV_BN_VAR, GGML_OP_DIV},
|
|
{PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, GGML_OP_NONE},
|
|
{PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, GGML_OP_IM2COL},
|
|
{PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, GGML_OP_MUL},
|
|
{PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, GGML_OP_ADD},
|
|
{PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, GGML_OP_ADD},
|
|
{PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, GGML_OP_ADD},
|
|
{PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, GGML_OP_MUL_MAT},
|
|
{PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, GGML_OP_MUL_MAT},
|
|
{PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, GGML_OP_MUL_MAT},
|
|
{PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, GGML_OP_MUL_MAT},
|
|
{PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, GGML_OP_MUL_MAT},
|
|
{PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, GGML_OP_MUL},
|
|
{PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, GGML_OP_ADD},
|
|
{PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, GGML_OP_MUL_MAT},
|
|
{PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, GGML_OP_MUL_MAT},
|
|
{PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, GGML_OP_MUL},
|
|
{PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, GGML_OP_ADD},
|
|
|
|
// Prediction network
|
|
{PARAKEET_TENSOR_PRED_EMBED_WEIGHT, GGML_OP_GET_ROWS},
|
|
{PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, GGML_OP_MUL_MAT},
|
|
{PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, GGML_OP_MUL_MAT},
|
|
{PARAKEET_TENSOR_PRED_LSTM_BIAS_H, GGML_OP_ADD},
|
|
|
|
// Joint network
|
|
{PARAKEET_TENSOR_JOINT_PRED_WEIGHT, GGML_OP_MUL_MAT},
|
|
{PARAKEET_TENSOR_JOINT_PRED_BIAS, GGML_OP_ADD},
|
|
{PARAKEET_TENSOR_JOINT_ENC_WEIGHT, GGML_OP_MUL_MAT},
|
|
{PARAKEET_TENSOR_JOINT_ENC_BIAS, GGML_OP_ADD},
|
|
{PARAKEET_TENSOR_JOINT_NET_WEIGHT, GGML_OP_MUL_MAT},
|
|
{PARAKEET_TENSOR_JOINT_NET_BIAS, GGML_OP_ADD},
|
|
};
|