coreml : fix --quantize crash for mlprogram format; fix --optimize-ane label (#3868)

commit 8b92060 switched ct.convert() to mlprogram, but did not update
the --quantize path.  quantize_weights() from
neural_network.quantization_utils only works with the legacy
neuralnetwork format.  Running with --quantize crashed with:

  Exception: MLModel of type mlProgram cannot be loaded just from the
  model spec object. It also needs the path to the weights file.

Fix: pass compute_precision=ct.precision.FLOAT16 into ct.convert() when
--quantize is set.  This matches the original intent of nbits=16 (F16
storage) without changing the quantization scheme or model accuracy.

Also fix the three boolean CLI flags (--encoder-only, --quantize,
--optimize-ane) to use a _str_to_bool helper so that both
  --flag True
and
  --flag False
parse correctly.  The type=bool form accepted "False" as True because
bool("False") == True.

Remove the "currently broken" label from --optimize-ane: the ANE path
(WhisperANE with Conv2d attention and LayerNormANE) converts and loads
correctly with both PyTorch 2.x and coremltools 9.x.
This commit is contained in:
Christopher Albert 2026-06-09 08:34:31 +02:00 committed by GitHub
parent 84bd03a438
commit ba573929cd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 15 additions and 10 deletions

View File

@ -8,10 +8,19 @@ from torch import nn
from typing import Dict
from typing import Optional
from ane_transformers.reference.layer_norm import LayerNormANE as LayerNormANEBase
from coremltools.models.neural_network.quantization_utils import quantize_weights
from whisper.model import Whisper, AudioEncoder, TextDecoder, ResidualAttentionBlock, MultiHeadAttention, ModelDimensions
from whisper import load_model
def _str_to_bool(v):
if isinstance(v, bool):
return v
if v.lower() in ("true", "1", "yes"):
return True
if v.lower() in ("false", "0", "no"):
return False
raise argparse.ArgumentTypeError(f"boolean value expected, got '{v}'")
# Disable PyTorch Scaled Dot-Product Attention (SDPA) to avoid compatibility issues.
# The Whisper implementation expects a specific behavior from
# torch.nn.functional.scaled_dot_product_attention that differs between PyTorch
@ -258,11 +267,9 @@ def convert_encoder(hparams, model, quantize=False):
inputs=[ct.TensorType(name="logmel_data", shape=input_shape)],
outputs=[ct.TensorType(name="output")],
compute_units=ct.ComputeUnit.ALL,
compute_precision=ct.precision.FLOAT16 if quantize else ct.precision.FLOAT32,
)
if quantize:
model = quantize_weights(model, nbits=16)
return model
def convert_decoder(hparams, model, quantize=False):
@ -283,20 +290,18 @@ def convert_decoder(hparams, model, quantize=False):
ct.TensorType(name="token_data", shape=tokens_shape, dtype=int),
ct.TensorType(name="audio_data", shape=audio_shape)
],
compute_precision=ct.precision.FLOAT16 if quantize else ct.precision.FLOAT32,
)
if quantize:
model = quantize_weights(model, nbits=16)
return model
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, help="model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large-v1, large-v2, large-v3, large-v3-turbo)", required=True)
parser.add_argument("--encoder-only", type=bool, help="only convert encoder", default=False)
parser.add_argument("--quantize", type=bool, help="quantize weights to F16", default=False)
parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False)
parser.add_argument("--encoder-only", type=_str_to_bool, help="only convert encoder", default=False)
parser.add_argument("--quantize", type=_str_to_bool, help="quantize weights to F16", default=False)
parser.add_argument("--optimize-ane", type=_str_to_bool, help="optimize for ANE execution", default=False)
args = parser.parse_args()
if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "small.en-tdrz", "medium", "medium.en", "large-v1", "large-v2", "large-v3", "large-v3-turbo"]: