whisper.cpp/models/convert-parakeet-to-ggml.py

338 lines
14 KiB
Python
Executable File

#!/usr/bin/env python3
# Convert Parakeet TDT model from NeMo format to ggml format
#
# Usage: python convert-parakeet-to-ggml.py --model parakeet-model.nemo --out-dir output-dir [--use-f32] [--out-name NAME]
#
# The NeMo file is a tar archive containing:
# - model_weights.ckpt (PyTorch checkpoint)
# - model_config.yaml (model configuration)
# - tokenizer files
#
# This script extracts the NeMo archive, loads the model weights and configuration,
# and saves them in ggml format compatible with whisper.cpp.
#
import torch
import argparse
import io
import os
import sys
import struct
import tarfile
import tempfile
import shutil
import yaml
import numpy as np
from pathlib import Path
from typing import Optional
def hz_to_mel(freq):
    """Convert a frequency in Hz to mels using 2595 * log10(1 + f / 700).

    Works element-wise on numpy arrays as well as on scalars.
    """
    ratio = freq / 700.0
    return 2595.0 * np.log10(ratio + 1.0)
def mel_to_hz(mel):
    """Invert ``hz_to_mel``: map mels back to Hz via 700 * (10**(m / 2595) - 1).

    Works element-wise on numpy arrays as well as on scalars.
    """
    exponent = mel / 2595.0
    return 700.0 * (10.0 ** exponent - 1.0)
def _check_tar_members(tar, extract_dir):
    """Raise tarfile.TarError if any member of *tar* would land or link outside *extract_dir*."""
    root = os.path.realpath(extract_dir)
    for member in tar.getmembers():
        target = os.path.realpath(os.path.join(root, member.name))
        candidates = [target]
        if member.issym():
            # Symlink targets are resolved relative to the link's own directory.
            candidates.append(os.path.realpath(os.path.join(os.path.dirname(target), member.linkname)))
        elif member.islnk():
            # Hard-link targets are archive paths, i.e. relative to the root.
            candidates.append(os.path.realpath(os.path.join(root, member.linkname)))
        for path in candidates:
            try:
                inside = os.path.commonpath([root, path]) == root
            except ValueError:  # e.g. different drives on Windows
                inside = False
            if not inside:
                raise tarfile.TarError(
                    f"Refusing to extract {member.name!r}: it resolves outside {extract_dir}"
                )


def extract_nemo_archive(nemo_path, extract_dir):
    """Extract a ``.nemo`` (tar) archive into ``extract_dir``.

    The archive may come from an untrusted source, so members that would be
    written outside ``extract_dir`` (absolute paths, ``..`` components, links
    pointing outside it) are rejected rather than extracted.

    Args:
        nemo_path: Path to the ``.nemo`` archive (any compression tarfile
            auto-detects in mode ``'r'``).
        extract_dir: Existing directory to extract into.

    Raises:
        tarfile.TarError: if the archive cannot be read or contains a member
            that would escape ``extract_dir``.
    """
    print(f"Extracting {nemo_path} to {extract_dir}")
    with tarfile.open(nemo_path, 'r') as tar:
        if hasattr(tarfile, 'data_filter'):
            # Python 3.12+ (and the 3.8-3.11 security backports) provide
            # extraction filters; 'data' blocks unsafe paths, links and
            # special files.
            tar.extractall(path=extract_dir, filter='data')
        else:
            _check_tar_members(tar, extract_dir)
            tar.extractall(path=extract_dir)
    print("Extraction complete")
def load_model_config(config_path):
    """Parse the ``model_config.yaml`` file extracted from a .nemo archive.

    ``yaml.safe_load`` is used, so only plain YAML data types are built.

    Returns:
        The parsed document; callers index it as a nested dict.
    """
    with open(config_path, encoding='utf-8') as config_file:
        return yaml.safe_load(config_file)
def load_tokenizer(extract_dir, config):
    """Read the tokenizer vocabulary shipped inside an extracted .nemo archive.

    Args:
        extract_dir: Directory the .nemo archive was extracted into. It must
            contain a ``*_tokenizer.model`` and a ``*tokenizer.vocab`` file.
        config: Parsed model config. When it is a dict providing
            ``decoder.vocab_size``, that value is the expected vocabulary size
            for the sanity-check warning; otherwise 8192 is assumed.

    Returns:
        dict mapping each token (UTF-8 bytes) to its id, which is the
        zero-based line number of the token in the vocab file. The token is
        the text before the first tab on each line.

    Raises:
        FileNotFoundError: if either tokenizer file is missing.
        ValueError: if the vocab file lists the same token twice; the dict
            could hold only one id, leaving a gap in the id sequence.
    """
    tokenizer_model_path = None
    tokenizer_vocab_path = None
    # Sorted so the selected files do not depend on os.listdir() ordering.
    for file in sorted(os.listdir(extract_dir)):
        if file.endswith('_tokenizer.model'):
            tokenizer_model_path = os.path.join(extract_dir, file)
        elif file.endswith('tokenizer.vocab'):
            tokenizer_vocab_path = os.path.join(extract_dir, file)
    if not tokenizer_model_path:
        raise FileNotFoundError("Tokenizer model file not found")
    if not tokenizer_vocab_path:
        raise FileNotFoundError("Tokenizer vocab file not found")

    tokens = {}
    with open(tokenizer_vocab_path, 'r', encoding='utf-8') as f:
        for idx, line in enumerate(f):
            # Remove only the line terminator: strip() would also eat
            # whitespace that is part of the token itself.
            token = line.rstrip('\r\n').split('\t', 1)[0]
            token_bytes = token.encode('utf-8')
            if token_bytes in tokens:
                raise ValueError(
                    f"Duplicate token {token!r} at line {idx} "
                    f"(first seen at line {tokens[token_bytes]})"
                )
            tokens[token_bytes] = idx
    print(f"Loaded {len(tokens)} tokens from {os.path.basename(tokenizer_vocab_path)}")

    expected = 8192
    if isinstance(config, dict):
        expected = (config.get('decoder') or {}).get('vocab_size', expected)
    if len(tokens) != expected:
        print(f"WARNING: Expected {expected} tokens, got {len(tokens)}")
    return tokens
def write_tensor(fout, name, data, use_f16=True, force_f32=False):
    """Append one named tensor record to an open ggml output file.

    Record layout (native byte order): int32 n_dims, int32 name length,
    int32 ftype (1 = float16, 0 = float32), n_dims int32 dimensions listed
    from the last numpy axis to the first, the UTF-8 name, then the element
    data in C order.

    Args:
        fout: Binary file object opened for writing (``ndarray.tofile`` is
            used for the payload, so it must be a real file).
        name: Tensor name; substrings of it decide the storage type.
        data: numpy array holding the tensor.
        use_f16: Store eligible tensors (rank >= 2, not a bias/norm,
            not a 4-D pre-encoder conv, not a depthwise conv weight) as float16.
        force_f32: Store as float32 regardless of ``use_f16``.
    """
    # 1-D pre-encoder conv biases are expanded to rank 4 as (1, C, 1, 1).
    if len(data.shape) == 1 and 'pre_encode.conv' in name and 'bias' in name:
        data = data.reshape(1, -1, 1, 1)
        print(f" Reshaped conv bias {name} to {data.shape}")

    rank = len(data.shape)
    keep_f32 = (
        force_f32
        or not use_f16
        or rank < 2
        or 'bias' in name
        or 'norm' in name
        or ('pre_encode.conv' in name and rank == 4)
        or 'depthwise_conv.weight' in name
    )
    ftype = 0 if keep_f32 else 1
    data = data.astype(np.float32 if keep_f32 else np.float16)

    # Last numpy axis first: it is the fastest-varying one in C-order data.
    dims = list(data.shape[::-1])
    print(f"Processing: {name} {list(data.shape)}, dtype: {data.dtype}, n_dims: {rank}, reversed: {dims}")

    encoded_name = name.encode('utf-8')
    record_header = struct.pack("iii", rank, len(encoded_name), ftype)
    record_header += b"".join(struct.pack("i", dim) for dim in dims)
    fout.write(record_header + encoded_name)
    data.tofile(fout)
def _tensor_to_numpy(tensor):
    """Return *tensor* as a CPU numpy array, upcasting bfloat16 (which numpy lacks) to float32."""
    tensor = tensor.detach().cpu()
    if tensor.dtype == torch.bfloat16:
        tensor = tensor.float()
    return tensor.numpy()


def _load_state_dict(weights_path):
    """Load ``model_weights.ckpt`` and return the tensor dict it contains."""
    print(f"\nLoading model weights from {weights_path}")
    try:
        # The .nemo archive may come from an untrusted source: weights_only=True
        # restricts unpickling to tensors and plain containers instead of
        # arbitrary (potentially code-executing) objects.
        checkpoint = torch.load(weights_path, map_location='cpu', weights_only=True)
    except TypeError:
        # torch < 1.13 has no weights_only argument; fall back to a full load.
        checkpoint = torch.load(weights_path, map_location='cpu')
    # Some checkpoints nest the tensors under a 'state_dict' key.
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    print(f"Loaded {len(state_dict)} tensors")
    return state_dict


def _build_hparams(config):
    """Map the NeMo model config onto the hyperparameters stored in the ggml header."""
    encoder = config['encoder']
    prednet = config['decoder']['prednet']
    return {
        # Fixed value; not read from the NeMo config.
        'n_audio_ctx': 5000,
        'n_audio_state': encoder['d_model'],
        'n_audio_head': encoder['n_heads'],
        'n_audio_layer': encoder['n_layers'],
        'n_mels': config['preprocessor']['features'],
        'n_fft': config['preprocessor']['n_fft'],
        'subsampling_factor': encoder['subsampling_factor'],
        'n_subsampling_channels': encoder['subsampling_conv_channels'],
        'n_conv_kernel': encoder['conv_kernel_size'],
        'n_pred_dim': prednet['pred_hidden'],
        'n_pred_layers': prednet['pred_rnn_layers'],
        'n_vocab': config['decoder']['vocab_size'],
        'n_tdt_durations': config['model_defaults']['num_tdt_durations'],
        'n_max_tokens': config['decoding']['greedy']['max_symbols'],
    }


def _write_header(fout, hparams, use_f16):
    """Write the ggml magic number followed by the int32 hyperparameters in fixed order."""
    fields = (
        0x67676d6c,  # magic: 'ggml' in hex
        hparams['n_vocab'],
        hparams['n_audio_ctx'],
        hparams['n_audio_state'],
        hparams['n_audio_head'],
        hparams['n_audio_layer'],
        hparams['n_mels'],
        1 if use_f16 else 0,  # 1 when f16 weights were requested, else 0
        hparams['n_fft'],
        hparams['subsampling_factor'],
        hparams['n_subsampling_channels'],
        hparams['n_conv_kernel'],
        hparams['n_pred_dim'],
        hparams['n_pred_layers'],
        hparams['n_tdt_durations'],
        hparams['n_max_tokens'],
    )
    for value in fields:
        fout.write(struct.pack("i", value))


def _write_mel_filterbank(fout, state_dict):
    """Write the model's mel filterbank (int32 n_mels, int32 n_freqs, float32 rows); return its key."""
    fb_key = next(
        (key for key in state_dict if 'featurizer.fb' in key or 'filterbank' in key.lower()),
        None,
    )
    if fb_key is None:
        print("\nERROR: Mel filterbank not found in model!")
        print("Expected tensor with 'featurizer.fb' or 'filterbank' in name")
        print("\nAvailable preprocessor tensors:")
        for key in sorted(state_dict):
            if 'preprocessor' in key or 'featurizer' in key:
                print(f" {key}: {state_dict[key].shape}")
        raise ValueError("Mel filterbank tensor not found in model")

    print(f"\nUsing model's mel filterbank from: {fb_key}")
    mel_filters = _tensor_to_numpy(state_dict[fb_key].squeeze()).astype(np.float32)
    print(f" Filterbank shape: {mel_filters.shape}")
    # Validate the rank before the diagnostics below index row 0.
    if mel_filters.ndim != 2:
        raise ValueError(f"Expected 2D filterbank, got shape {mel_filters.shape}")
    print(f" Filterbank min/max values: {mel_filters.min():.6f} / {mel_filters.max():.6f}")
    print(f" Filterbank non-zero elements: {np.count_nonzero(mel_filters)} / {mel_filters.size}")
    print(f" First row sum: {mel_filters[0].sum():.6f}")

    n_mels, n_freqs = mel_filters.shape
    fout.write(struct.pack("i", n_mels))  # n_mel
    fout.write(struct.pack("i", n_freqs))  # n_fb (frequency bins)
    # Row-major native-order float32: the same bytes as packing every element
    # with struct "f", written in one call.
    fout.write(mel_filters.tobytes())
    return fb_key


def _write_window(fout, state_dict):
    """Write the model's STFT window (int32 length, float32 samples); return its key."""
    window_key = next(
        (key for key in state_dict
         if 'featurizer.window' in key or ('preproc' in key and 'window' in key)),
        None,
    )
    if window_key is None:
        print("\nERROR: Window function not found in model!")
        print("Expected tensor with 'featurizer.window' in name")
        raise ValueError("Window function tensor not found in model")

    print(f"\nUsing model's window function from: {window_key}")
    window = _tensor_to_numpy(state_dict[window_key].squeeze()).astype(np.float32)
    print(f" Window shape: {window.shape}")
    if window.ndim != 1:
        raise ValueError(f"Expected 1D window, got shape {window.shape}")
    print(f" Window min/max values: {window.min():.6f} / {window.max():.6f}")
    print(f" Window non-zero elements: {np.count_nonzero(window)} / {window.size}")
    print(f" Window sum: {window.sum():.6f}")

    fout.write(struct.pack("i", window.shape[0]))
    fout.write(window.tobytes())
    return window_key


def _write_tdt_durations(fout, config, n_tdt_durations):
    """Write each TDT duration from the config as a uint32, after checking the count."""
    tdt_durations = config['model_defaults']['tdt_durations']
    if len(tdt_durations) != n_tdt_durations:
        raise ValueError(f"TDT durations count mismatch: {len(tdt_durations)} vs {n_tdt_durations}")
    for duration in tdt_durations:
        fout.write(struct.pack("I", duration))


def _write_vocab(fout, tokens):
    """Write the token count, then each token as (int32 byte length, raw bytes) in id order."""
    fout.write(struct.pack("i", len(tokens)))
    for token_bytes, _ in sorted(tokens.items(), key=lambda item: item[1]):
        fout.write(struct.pack("i", len(token_bytes)))
        fout.write(token_bytes)


def _write_weights(fout, state_dict, skip_keys, use_f16):
    """Write every model tensor not in *skip_keys*, folding and reordering the prediction LSTM."""
    lstm_prefix = 'decoder.prediction.dec_rnn.lstm'
    bias_ih_marker = f'{lstm_prefix}.bias_ih_l'

    # Pre-collect prediction LSTM input-hidden biases so they can be folded
    # into the hidden-hidden bias of the same layer.
    pred_bias_ih = {}
    for key, t in state_dict.items():
        if bias_ih_marker in key:
            layer_idx = int(key.rsplit('bias_ih_l', 1)[1])
            pred_bias_ih[layer_idx] = _tensor_to_numpy(t.squeeze()).astype(np.float32)

    for name, tensor in state_dict.items():
        # Filterbank/window are already written in the preprocessing section;
        # bias_ih is folded into bias_hh below, so it is not written alone.
        if name in skip_keys or bias_ih_marker in name:
            continue

        # Conv2d weights keep all 4 dimensions; other tensors drop size-1 axes.
        if 'conv' in name and 'weight' in name and len(tensor.shape) == 4:
            data = _tensor_to_numpy(tensor)
        else:
            data = _tensor_to_numpy(tensor.squeeze())

        # Prediction LSTM weights/biases: fold bias_ih into bias_hh (renamed
        # bias_h_l<k>), then reorder gate blocks from PyTorch's [i, f, g, o]
        # to [i, f, o, g] so the three sigmoid-gated blocks (i, f, o) are
        # contiguous.
        if name.startswith(f'{lstm_prefix}.'):
            if f'{lstm_prefix}.bias_hh_l' in name:
                layer_idx = int(name.rsplit('bias_hh_l', 1)[1])
                data = data.astype(np.float32) + pred_bias_ih[layer_idx]
                name = name.replace('bias_hh_l', 'bias_h_l')
            h = data.shape[0] // 4
            data = np.concatenate([data[:h], data[h:2 * h], data[3 * h:], data[2 * h:3 * h]], axis=0)

        write_tensor(fout, name, data, use_f16=use_f16)


def convert_parakeet_to_ggml(nemo_path, output_dir, use_f16=True, out_name=None):
    """Convert a Parakeet TDT ``.nemo`` archive into a single ggml model file.

    Output layout, written in this order:
      1. magic ``0x67676d6c`` and 15 int32 hyperparameters (``_write_header``);
      2. mel filterbank: int32 n_mels, int32 n_freqs, float32 data;
      3. window: int32 length, float32 samples;
      4. TDT durations as uint32;
      5. vocabulary: int32 count, then (int32 length, bytes) per token;
      6. one ``write_tensor`` record per remaining state-dict tensor.

    Args:
        nemo_path: Path to the ``.nemo`` archive.
        output_dir: Directory for the output file (created if missing).
        use_f16: Store eligible weights as float16 (see ``write_tensor``).
        out_name: Output file name; defaults to ``ggml-model.bin`` or
            ``ggml-model-f32.bin`` depending on ``use_f16``.

    Raises:
        ValueError: if the filterbank/window tensors are missing or have an
            unexpected rank, or the TDT duration count disagrees with the config.
        FileNotFoundError: if tokenizer files are missing (from ``load_tokenizer``).
    """
    nemo_path = Path(nemo_path)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    with tempfile.TemporaryDirectory() as temp_dir:
        extract_nemo_archive(nemo_path, temp_dir)

        config = load_model_config(os.path.join(temp_dir, 'model_config.yaml'))
        print("Model configuration:")
        print(f" Sample rate: {config['sample_rate']}")
        print(f" Encoder layers: {config['encoder']['n_layers']}")
        print(f" Encoder d_model: {config['encoder']['d_model']}")
        print(f" Mel features: {config['preprocessor']['features']}")

        state_dict = _load_state_dict(os.path.join(temp_dir, 'model_weights.ckpt'))

        print("\nLoading tokenizer...")
        tokens = load_tokenizer(temp_dir, config)
        print(f"Loaded {len(tokens)} tokens")

        hparams = _build_hparams(config)
        print("\nGGML hyperparameters:")
        for key, value in hparams.items():
            print(f" {key}: {value}")

        if out_name:
            fname_out = output_dir / out_name
        else:
            fname_out = output_dir / ("ggml-model-f32.bin" if not use_f16 else "ggml-model.bin")
        print(f"\nWriting to {fname_out}")

        with open(fname_out, 'wb') as fout:
            _write_header(fout, hparams, use_f16)
            fb_key = _write_mel_filterbank(fout, state_dict)
            window_key = _write_window(fout, state_dict)
            _write_tdt_durations(fout, config, hparams['n_tdt_durations'])
            _write_vocab(fout, tokens)
            print("\nConverting model weights...")
            _write_weights(fout, state_dict, {fb_key, window_key}, use_f16)

    # Reported after the output file is closed so the size reflects all data.
    print("\nConversion complete!")
    print(f"Output file: {fname_out}")
    print(f"File size: {fname_out.stat().st_size / (1024**2):.2f} MB")
def main(argv=None):
    """Command-line entry point: parse *argv* (default ``sys.argv[1:]``) and run the conversion.

    Exits with status 1 when ``--model`` does not name an existing file.
    """
    parser = argparse.ArgumentParser(
        description='Convert Parakeet TDT model from NeMo format to ggml format'
    )
    parser.add_argument('--model', type=str, required=True,
                        help='Path to Parakeet .nemo model file')
    # '--output-dir' is accepted as an alias for the spelling used in older docs.
    parser.add_argument('--out-dir', '--output-dir', dest='out_dir', type=str, required=True,
                        help='Directory to write ggml model file')
    parser.add_argument('--use-f32', action='store_true', default=False,
                        help='Use f32 instead of f16 (default: f16)')
    parser.add_argument('--out-name', type=str, default=None,
                        help='Output file name (default: ggml-model.bin or ggml-model-f32.bin)')
    args = parser.parse_args(argv)

    # isfile (not exists): a directory path cannot be opened as a tar archive.
    if not os.path.isfile(args.model):
        print(f"Error: {args.model} not found", file=sys.stderr)
        sys.exit(1)

    use_f16 = not args.use_f32
    convert_parakeet_to_ggml(args.model, args.out_dir, use_f16, args.out_name)


if __name__ == '__main__':
    main()