From ba573929cd31ddea3c77c5dc9caae78da8117123 Mon Sep 17 00:00:00 2001 From: Christopher Albert Date: Tue, 9 Jun 2026 08:34:31 +0200 Subject: [PATCH] coreml : fix --quantize crash for mlprogram format; fix --optimize-ane label (#3868) commit 8b92060 switched ct.convert() to mlprogram, but did not update the --quantize path. quantize_weights() from neural_network.quantization_utils only works with the legacy neuralnetwork format. Running with --quantize crashed with: Exception: MLModel of type mlProgram cannot be loaded just from the model spec object. It also needs the path to the weights file. Fix: pass compute_precision=ct.precision.FLOAT16 into ct.convert() when --quantize is set. This matches the original intent of nbits=16 (F16 storage) without changing the quantization scheme or model accuracy. Also fix the three boolean CLI flags (--encoder-only, --quantize, --optimize-ane) to use a _str_to_bool helper so that both --flag True and --flag False parse correctly. The type=bool form accepted "False" as True because bool("False") == True. Remove the "currently broken" label from --optimize-ane: the ANE path (WhisperANE with Conv2d attention and LayerNormANE) converts and loads correctly with both PyTorch 2.x and coremltools 9.x. --- models/convert-whisper-to-coreml.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/models/convert-whisper-to-coreml.py b/models/convert-whisper-to-coreml.py index 66827b6d4..7cf07754a 100644 --- a/models/convert-whisper-to-coreml.py +++ b/models/convert-whisper-to-coreml.py @@ -8,10 +8,19 @@ from torch import nn from typing import Dict from typing import Optional from ane_transformers.reference.layer_norm import LayerNormANE as LayerNormANEBase -from coremltools.models.neural_network.quantization_utils import quantize_weights from whisper.model import Whisper, AudioEncoder, TextDecoder, ResidualAttentionBlock, MultiHeadAttention, ModelDimensions from whisper import load_model + +def _str_to_bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("true", "1", "yes"): + return True + if v.lower() in ("false", "0", "no"): + return False + raise argparse.ArgumentTypeError(f"boolean value expected, got '{v}'") + # Disable PyTorch Scaled Dot-Product Attention (SDPA) to avoid compatibility issues. # The Whisper implementation expects a specific behavior from # torch.nn.functional.scaled_dot_product_attention that differs between PyTorch @@ -258,11 +267,9 @@ def convert_encoder(hparams, model, quantize=False): inputs=[ct.TensorType(name="logmel_data", shape=input_shape)], outputs=[ct.TensorType(name="output")], compute_units=ct.ComputeUnit.ALL, + compute_precision=ct.precision.FLOAT16 if quantize else ct.precision.FLOAT32, ) - if quantize: - model = quantize_weights(model, nbits=16) - return model def convert_decoder(hparams, model, quantize=False): @@ -283,20 +290,18 @@ def convert_decoder(hparams, model, quantize=False): ct.TensorType(name="token_data", shape=tokens_shape, dtype=int), ct.TensorType(name="audio_data", shape=audio_shape) ], + compute_precision=ct.precision.FLOAT16 if quantize else ct.precision.FLOAT32, ) - if quantize: - model = quantize_weights(model, nbits=16) - return model if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, help="model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large-v1, large-v2, large-v3, large-v3-turbo)", required=True) - parser.add_argument("--encoder-only", type=bool, help="only convert encoder", default=False) - parser.add_argument("--quantize", type=bool, help="quantize weights to F16", default=False) - parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False) + parser.add_argument("--encoder-only", type=_str_to_bool, help="only convert encoder", default=False) + parser.add_argument("--quantize", type=_str_to_bool, help="quantize weights to F16", default=False) + parser.add_argument("--optimize-ane", type=_str_to_bool, help="optimize for ANE execution", default=False) args = parser.parse_args() if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "small.en-tdrz", "medium", "medium.en", "large-v1", "large-v2", "large-v3", "large-v3-turbo"]: