coreml : fix --quantize crash for mlprogram format; fix --optimize-ane label (#3868)
commit 8b92060 switched ct.convert() to mlprogram, but did not update
the --quantize path. quantize_weights() from
neural_network.quantization_utils only works with the legacy
neuralnetwork format. Running with --quantize crashed with:
Exception: MLModel of type mlProgram cannot be loaded just from the
model spec object. It also needs the path to the weights file.
Fix: pass compute_precision=ct.precision.FLOAT16 into ct.convert() when
--quantize is set. This matches the original intent of nbits=16 (F16
storage) without changing the quantization scheme or model accuracy.
Also fix the three boolean CLI flags (--encoder-only, --quantize,
--optimize-ane) to use a _str_to_bool helper so that both
--flag True
and
--flag False
parse correctly. The type=bool form accepted "False" as True because
bool("False") == True.
Remove the "currently broken" label from --optimize-ane: the ANE path
(WhisperANE with Conv2d attention and LayerNormANE) converts and loads
correctly with both PyTorch 2.x and coremltools 9.x.
This commit is contained in:
parent
84bd03a438
commit
ba573929cd
|
|
@ -8,10 +8,19 @@ from torch import nn
|
|||
from typing import Dict
|
||||
from typing import Optional
|
||||
from ane_transformers.reference.layer_norm import LayerNormANE as LayerNormANEBase
|
||||
from coremltools.models.neural_network.quantization_utils import quantize_weights
|
||||
from whisper.model import Whisper, AudioEncoder, TextDecoder, ResidualAttentionBlock, MultiHeadAttention, ModelDimensions
|
||||
from whisper import load_model
|
||||
|
||||
|
||||
def _str_to_bool(v):
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
if v.lower() in ("true", "1", "yes"):
|
||||
return True
|
||||
if v.lower() in ("false", "0", "no"):
|
||||
return False
|
||||
raise argparse.ArgumentTypeError(f"boolean value expected, got '{v}'")
|
||||
|
||||
# Disable PyTorch Scaled Dot-Product Attention (SDPA) to avoid compatibility issues.
|
||||
# The Whisper implementation expects a specific behavior from
|
||||
# torch.nn.functional.scaled_dot_product_attention that differs between PyTorch
|
||||
|
|
@ -258,11 +267,9 @@ def convert_encoder(hparams, model, quantize=False):
|
|||
inputs=[ct.TensorType(name="logmel_data", shape=input_shape)],
|
||||
outputs=[ct.TensorType(name="output")],
|
||||
compute_units=ct.ComputeUnit.ALL,
|
||||
compute_precision=ct.precision.FLOAT16 if quantize else ct.precision.FLOAT32,
|
||||
)
|
||||
|
||||
if quantize:
|
||||
model = quantize_weights(model, nbits=16)
|
||||
|
||||
return model
|
||||
|
||||
def convert_decoder(hparams, model, quantize=False):
|
||||
|
|
@ -283,20 +290,18 @@ def convert_decoder(hparams, model, quantize=False):
|
|||
ct.TensorType(name="token_data", shape=tokens_shape, dtype=int),
|
||||
ct.TensorType(name="audio_data", shape=audio_shape)
|
||||
],
|
||||
compute_precision=ct.precision.FLOAT16 if quantize else ct.precision.FLOAT32,
|
||||
)
|
||||
|
||||
if quantize:
|
||||
model = quantize_weights(model, nbits=16)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model", type=str, help="model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large-v1, large-v2, large-v3, large-v3-turbo)", required=True)
|
||||
parser.add_argument("--encoder-only", type=bool, help="only convert encoder", default=False)
|
||||
parser.add_argument("--quantize", type=bool, help="quantize weights to F16", default=False)
|
||||
parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False)
|
||||
parser.add_argument("--encoder-only", type=_str_to_bool, help="only convert encoder", default=False)
|
||||
parser.add_argument("--quantize", type=_str_to_bool, help="quantize weights to F16", default=False)
|
||||
parser.add_argument("--optimize-ane", type=_str_to_bool, help="optimize for ANE execution", default=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "small.en-tdrz", "medium", "medium.en", "large-v1", "large-v2", "large-v3", "large-v3-turbo"]:
|
||||
|
|
|
|||
Loading…
Reference in New Issue