import argparse
import json
import os

import numpy as np
import onnx
import onnxruntime
from onnxruntime.quantization import QuantFormat, QuantType, StaticQuantConfig, quantize
from onnxruntime.quantization.calibrate import CalibrationDataReader, CalibrationMethod


class OnnxModelCalibrationDataReader(CalibrationDataReader):
    """Reads calibration inputs from the test_data_set_* directories next to the model."""

    def __init__(self, model_path):
        self.model_dir = os.path.dirname(model_path)
        data_dirs = [
            os.path.join(self.model_dir, a)
            for a in os.listdir(self.model_dir)
            if a.startswith("test_data_set_")
        ]
        model_inputs = onnxruntime.InferenceSession(model_path).get_inputs()
        name2tensors = []
        for data_dir in data_dirs:
            name2tensor = {}
            # Sorted file order pairs the serialized tensors with the model inputs.
            data_paths = [os.path.join(data_dir, a) for a in sorted(os.listdir(data_dir))]
            data_ndarrays = [self.read_onnx_pb_data(data_path) for data_path in data_paths]
            for model_input, data_ndarray in zip(model_inputs, data_ndarrays, strict=False):
                name2tensor[model_input.name] = data_ndarray
            name2tensors.append(name2tensor)
        assert len(name2tensors) == len(data_dirs)
        assert len(name2tensors[0]) == len(model_inputs)

        self.calibration_data = iter(name2tensors)

    def get_next(self) -> dict:
        """Generate the input data dict for an ONNX InferenceSession run."""
        return next(self.calibration_data, None)

    def read_onnx_pb_data(self, file_pb):
        tensor = onnx.TensorProto()
        with open(file_pb, "rb") as f:
            tensor.ParseFromString(f.read())
        return onnx.numpy_helper.to_array(tensor)


def parse_arguments():
    parser = argparse.ArgumentParser(description="The arguments for static quantization")
    parser.add_argument("-i", "--input_model_path", required=True, help="Path to the input ONNX model")
    parser.add_argument(
        "-o", "--output_quantized_model_path", required=True, help="Path to the output quantized ONNX model"
    )
    parser.add_argument(
        "--activation_type",
        choices=["qint8", "quint8", "qint16", "quint16", "qint4", "quint4", "qfloat8e4m3fn"],
        default="quint8",
        help="Activation quantization type used",
    )
    parser.add_argument(
        "--weight_type",
        choices=["qint8", "quint8", "qint16", "quint16", "qint4", "quint4", "qfloat8e4m3fn"],
        default="qint8",
        help="Weight quantization type used",
    )
    parser.add_argument("--enable_subgraph", action="store_true", help="If set, subgraphs will be quantized.")
    parser.add_argument(
        "--force_quantize_no_input_check",
        action="store_true",
        help="By default, latent operators such as MaxPool and Transpose are not quantized if their input is"
        " not already quantized. If set, such operators are forced to quantize their input and thus generate"
        " quantized output. This behavior can still be disabled per node via --nodes_to_exclude.",
    )
    parser.add_argument(
        "--matmul_const_b_only",
        action="store_true",
        help="If set, only MatMul ops with a constant B input will be quantized.",
    )
    parser.add_argument(
        "--add_qdq_pair_to_weight",
        action="store_true",
        help="If set, weights remain in floating point and a QuantizeLinear/DeQuantizeLinear pair is"
        " inserted for each weight.",
    )
    parser.add_argument(
        "--dedicated_qdq_pair",
        action="store_true",
        help="If set, an identical and dedicated QDQ pair is created for each node.",
    )
    parser.add_argument(
        "--op_types_to_exclude_output_quantization",
        nargs="+",
        default=[],
        help="If op types are specified, the outputs of ops with those op types are not quantized.",
    )
    parser.add_argument(
        "--calibration_method",
        default="minmax",
        choices=["minmax", "entropy", "percentile", "distribution"],
        help="Calibration method used",
    )
    parser.add_argument("--quant_format", default="qdq", choices=["qdq", "qoperator"], help="Quantization format used")
    parser.add_argument(
        "--calib_tensor_range_symmetric",
        action="store_true",
        help="If enabled, the final tensor range computed during calibration is explicitly"
        " made symmetric around the central point 0.",
    )
    # TODO: --calib_strided_minmax
    # TODO: --calib_moving_average_constant
    # TODO: --calib_max_intermediate_outputs
    parser.add_argument(
        "--calib_moving_average",
        action="store_true",
        help="If enabled, the moving average of the minimum and maximum values is computed when the"
        " selected calibration method is MinMax.",
    )
    parser.add_argument(
        "--disable_quantize_bias",
        action="store_true",
        help="If set, biases remain in floating point and no quantization nodes are inserted for them."
        " By default, floating-point biases are quantized by solely inserting a DeQuantizeLinear node.",
    )
    # TODO: Add arguments related to Smooth Quant
    parser.add_argument(
        "--use_qdq_contrib_ops",
        action="store_true",
        help="If set, the inserted QuantizeLinear and DequantizeLinear ops will have the com.microsoft domain,"
        " which forces use of ONNX Runtime's QuantizeLinear and DequantizeLinear contrib op implementations.",
    )
    parser.add_argument(
        "--minimum_real_range",
        type=float,
        default=0.0001,
        help="If set to a floating-point value, the calculation of the quantization parameters"
        " (i.e., scale and zero point) will enforce a minimum range between rmin and rmax. If (rmax - rmin)"
        " is less than the specified minimum range, rmax will be set to rmin + minimum_real_range. This is"
        " necessary for EPs like QNN that require a minimum floating-point range when determining"
        " quantization parameters.",
    )
    parser.add_argument(
        "--qdq_keep_removable_activations",
        action="store_true",
        help="If set, removable activations (e.g., Clip or Relu) will not be removed"
        " and will be explicitly represented in the QDQ model.",
    )
    parser.add_argument(
        "--qdq_disable_weight_adjust_for_int32_bias",
        action="store_true",
        help="If set, the QDQ quantizer will not adjust the weight's scale when the bias"
        " has a scale (input_scale * weight_scale) that is too small.",
    )
    parser.add_argument("--per_channel", action="store_true", help="If set, per-channel quantization is used.")
    parser.add_argument(
        "--nodes_to_quantize",
        nargs="+",
        default=None,
        help="List of node names to quantize. When this list is given, only the nodes in it are quantized.",
    )
    parser.add_argument(
        "--nodes_to_exclude",
        nargs="+",
        default=None,
        help="List of node names to exclude from quantization.",
    )
    parser.add_argument(
        "--op_per_channel_axis",
        nargs=2,
        action="append",
        metavar=("OP_TYPE", "PER_CHANNEL_AXIS"),
        default=[],
        help="Set the channel axis for a specific op type, for example: --op_per_channel_axis MatMul 1. It is"
        " effective only when per-channel quantization is supported and --per_channel is set. If an op type"
        " supports per-channel quantization but no channel axis is explicitly specified, the default channel"
        " axis is used.",
    )
    parser.add_argument("--tensor_quant_overrides", help="Path to the JSON file with tensor quantization overrides.")
    return parser.parse_args()


def get_tensor_quant_overrides(file):
    # TODO: Enhance the function to handle more real cases of the JSON file
    if not file:
        return {}
    with open(file) as f:
        quant_override_dict = json.load(f)
    # Convert the JSON scalars/lists into the numpy arrays expected by the quantizer.
    for tensor in quant_override_dict:
        for enc_dict in quant_override_dict[tensor]:
            enc_dict["scale"] = np.array(enc_dict["scale"], dtype=np.float32)
            enc_dict["zero_point"] = np.array(enc_dict["zero_point"])
    return quant_override_dict


def main():
    args = parse_arguments()
    data_reader = OnnxModelCalibrationDataReader(model_path=args.input_model_path)
    arg2quant_type = {
        "qint8": QuantType.QInt8,
        "quint8": QuantType.QUInt8,
        "qint16": QuantType.QInt16,
        "quint16": QuantType.QUInt16,
        "qint4": QuantType.QInt4,
        "quint4": QuantType.QUInt4,
        "qfloat8e4m3fn": QuantType.QFLOAT8E4M3FN,
    }
    activation_type = arg2quant_type[args.activation_type]
    weight_type = arg2quant_type[args.weight_type]
    qdq_op_type_per_channel_support_to_axis = dict(args.op_per_channel_axis)
    extra_options = {
        "EnableSubgraph": args.enable_subgraph,
        "ForceQuantizeNoInputCheck": args.force_quantize_no_input_check,
        "MatMulConstBOnly": args.matmul_const_b_only,
        "AddQDQPairToWeight": args.add_qdq_pair_to_weight,
        "OpTypesToExcludeOutputQuantization": args.op_types_to_exclude_output_quantization,
        "DedicatedQDQPair": args.dedicated_qdq_pair,
        "QDQOpTypePerChannelSupportToAxis": qdq_op_type_per_channel_support_to_axis,
        "CalibTensorRangeSymmetric": args.calib_tensor_range_symmetric,
        "CalibMovingAverage": args.calib_moving_average,
        "QuantizeBias": not args.disable_quantize_bias,
        "UseQDQContribOps": args.use_qdq_contrib_ops,
        "MinimumRealRange": args.minimum_real_range,
        "QDQKeepRemovableActivations": args.qdq_keep_removable_activations,
        "QDQDisableWeightAdjustForInt32Bias": args.qdq_disable_weight_adjust_for_int32_bias,
        # Load the JSON file for encoding overrides
        "TensorQuantOverrides": get_tensor_quant_overrides(args.tensor_quant_overrides),
    }
    arg2calib_method = {
        "minmax": CalibrationMethod.MinMax,
        "entropy": CalibrationMethod.Entropy,
        "percentile": CalibrationMethod.Percentile,
        "distribution": CalibrationMethod.Distribution,
    }
    arg2quant_format = {
        "qdq": QuantFormat.QDQ,
        "qoperator": QuantFormat.QOperator,
    }
    sqc = StaticQuantConfig(
        calibration_data_reader=data_reader,
        calibrate_method=arg2calib_method[args.calibration_method],
        quant_format=arg2quant_format[args.quant_format],
        activation_type=activation_type,
        weight_type=weight_type,
        op_types_to_quantize=None,
        nodes_to_quantize=args.nodes_to_quantize,
        nodes_to_exclude=args.nodes_to_exclude,
        per_channel=args.per_channel,
        reduce_range=False,
        use_external_data_format=False,
        calibration_providers=None,  # Use CPUExecutionProvider
        extra_options=extra_options,
    )
    quantize(model_input=args.input_model_path, model_output=args.output_quantized_model_path, quant_config=sqc)


if __name__ == "__main__":
    main()
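# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the script). Assumptions: the script
# is saved as quantize_static.py, and model.onnx sits next to one or more
# test_data_set_<N>/ directories containing serialized input tensors (e.g.,
# input_0.pb), as OnnxModelCalibrationDataReader above requires:
#
#   python quantize_static.py -i model.onnx -o model.quant.onnx \
#       --calibration_method minmax --quant_format qdq --per_channel \
#       --tensor_quant_overrides overrides.json
#
# A hypothetical overrides.json in the shape that get_tensor_quant_overrides()
# parses: a dict mapping a tensor name to a list of override dicts whose
# "scale" and "zero_point" entries are converted to numpy arrays before being
# passed to the quantizer as the TensorQuantOverrides extra option:
#
#   {"conv1_weight": [{"scale": 0.0123, "zero_point": 0}]}
# ---------------------------------------------------------------------------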