diff --git a/backends/qualcomm/_passes/annotate_quant_attrs.py b/backends/qualcomm/_passes/annotate_quant_attrs.py
index 6077d51b099..f6b8d3ebace 100644
--- a/backends/qualcomm/_passes/annotate_quant_attrs.py
+++ b/backends/qualcomm/_passes/annotate_quant_attrs.py
@@ -25,6 +25,12 @@
 from .utils import get_quant_attrs
 
 
+EDGE_CAT_OPS = {
+    exir_ops.edge.aten.cat.default,
+    exir_ops.edge.aten.concat.default,
+}
+
+
 class AnnotateQuantAttrs(ExportPass):
     """
     Add "quant_attrs" to graph nodes' meta from the QDQ information
@@ -79,11 +85,56 @@ def _find_last_dq_nodes(self, node: torch.fx.node.Node) -> torch.fx.node.Node:
 
         return last_dq_nodes
 
+    def _is_requant_needed(self, src_attrs: Dict[str, Any], dst_attrs: Dict[str, Any]) -> bool:
+        if self.skip_advanced_requant:
+            return src_attrs[QCOM_DTYPE] != dst_attrs[QCOM_DTYPE]
+
+        return any(
+            src_attrs[attr] != dst_attrs[attr]
+            for attr in [
+                QCOM_SCALE,
+                QCOM_ZERO_POINT,
+                QCOM_QUANT_MIN,
+                QCOM_QUANT_MAX,
+                QCOM_DTYPE,
+            ]
+        )
+
+    def _annotate_cat_requant(self, quant_node: torch.fx.Node) -> None:
+        cat_node = quant_node.args[0]
+        if cat_node.target not in EDGE_CAT_OPS:
+            return
+
+        output_q_attrs = get_quant_attrs(self.edge_program, quant_node)
+        for input_node in cat_node.args[0]:
+            # only process q->dq->cat
+            if input_node.target not in dq_ops:
+                continue
+
+            source_q_node = input_node.args[0]
+            if source_q_node.target not in q_ops:
+                continue
+
+            source_q_attrs = get_quant_attrs(self.edge_program, source_q_node)
+            if not self._is_requant_needed(source_q_attrs, output_q_attrs):
+                continue
+
+            source_node = source_q_node.args[0]
+            # the producer ahead of the q node must be an FX node, since we store the requantize metadata on it
+            if not isinstance(source_node, torch.fx.Node):
+                continue
+
+            requant_attrs = output_q_attrs.copy()
+            requant_attrs[QCOM_ENCODING] = source_q_attrs[QCOM_ENCODING]
+            source_node.meta.setdefault(QCOM_REQUANTIZE, {})
+            source_node.meta[QCOM_REQUANTIZE][cat_node.name] = requant_attrs
+
     def _annotate_requant(self, n):
         # Record requant attributes:
         #   node1 -> q_ui8 (n) -> dq_ui8 -> q_int32 -> dq_int32 -> node2 -> ....
         # We store {node2: quant_attr in dq_int32} in node1.meta
         if n.target in q_ops and n.args[0].target not in dq_ops:
+            self._annotate_cat_requant(n)
             # for some fixed scale op, there is no need to requantize it
             if n.args[0].target in self.skip_requant_allowlist:
                 return
@@ -96,28 +147,7 @@ def _annotate_requant(self, n):
         # that has multiple outputs that requires quant attributes.
 
         # Determine if requantization is needed based on configuration and attribute mismatch.
-        is_requant_needed = False
-        if self.skip_advanced_requant:
-            # In skip_advanced_requant mode, only consider requant if dtypes differ.
-            if q_attrs[QCOM_DTYPE] != dq_attrs[QCOM_DTYPE]:
-                is_requant_needed = True
-        else:
-            # In full requant mode, consider requant if any key attribute differs.
-            # This aims to improve accuracy by adjusting scale, zero_point, etc.
-            # Users can disable this if it causes regressions.
-            if any(
-                q_attrs[attr] != dq_attrs[attr]
-                for attr in [
-                    QCOM_SCALE,
-                    QCOM_ZERO_POINT,
-                    QCOM_QUANT_MIN,
-                    QCOM_QUANT_MAX,
-                    QCOM_DTYPE,
-                ]
-            ):
-                is_requant_needed = True
-
-        if is_requant_needed:
+        if self._is_requant_needed(q_attrs, dq_attrs):
             dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
             user_node = list(dq_node.users)[0]
             n.args[0].meta.setdefault(QCOM_REQUANTIZE, {})
diff --git a/backends/qualcomm/quantizer/annotators/htp_rules.py b/backends/qualcomm/quantizer/annotators/htp_rules.py
index cd65d02c752..0cb0e2ef0dc 100644
--- a/backends/qualcomm/quantizer/annotators/htp_rules.py
+++ b/backends/qualcomm/quantizer/annotators/htp_rules.py
@@ -12,10 +12,6 @@
 import executorch.backends.qualcomm.builders.qnn_constants as QnnConstants
 import torch
-
-from executorch.backends.qualcomm.quantizer.observers.concat_observer import (
-    ConcatObserver,
-)
 from executorch.backends.qualcomm.quantizer.qconfig import (
     get_16a16w_qnn_ptq_config,
     get_16a4w_qnn_qat_config,
@@ -234,35 +230,25 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
         if _is_annotated([node]) or not _is_float_tensor(node):
             return
 
-        input_qspec_map, input_nodes = {}, node.args[0]
-        for input in input_nodes:
-            input_qspec = input.meta.get(Q_ANNOTATION_KEY, None)
-            qspec = getattr(input_qspec, "output_qspec", None)
-            # keep shared qspec here for propagation the data range
-            # without introducing extra requantizations
-            if isinstance(qspec, SharedQuantizationSpec):
-                input_qspec_map[input] = SharedQuantizationSpec(input)
-            else:
-                input_qspec_map[input] = quantization_config.input_activation
+        input_nodes = node.args[0]
+        assert isinstance(input_nodes, Sequence)
+
+        first_input_node = input_nodes[0]
+        assert isinstance(first_input_node, Node)
+
+        # the cat output shares its qparams with the first input; the other
+        # inputs keep their own observers and are requantized later if needed
+        input_qspec_map = {
+            first_input_node: quantization_config.input_activation,
+        }
+        for input_node in input_nodes[1:]:
+            assert isinstance(input_node, Node)
+            if input_node not in input_qspec_map:
+                input_qspec_map[input_node] = quantization_config.input_activation
 
-        output_qspec = QuantizationSpec(
-            dtype=quantization_config.output_activation.dtype,
-            qscheme=quantization_config.output_activation.qscheme,
-            quant_max=quantization_config.output_activation.quant_max,
-            quant_min=quantization_config.output_activation.quant_min,
-            observer_or_fake_quant_ctr=ConcatObserver.with_args(
-                # we need to know the concat node in order to hack all the input observers' data range
-                # since deep copy of fake tensor (node.meta["val"]) is inhibited
-                # we could only ship grap & node name and perform postprocess inside observer currently
-                **{
-                    "node_name": node.name,
-                    "graph": node.graph,
-                }
-            ),
-        )
         node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
             input_qspec_map=input_qspec_map,
-            output_qspec=output_qspec,
+            output_qspec=SharedQuantizationSpec((first_input_node, node)),
             _annotated=True,
         )
 
@@ -295,6 +281,7 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
 
 @register_annotator(
     [
         torch.ops.aten.split_with_sizes.default,
+        torch.ops.aten.split_with_sizes_copy.default,
         torch.ops.aten.split.Tensor,
         torch.ops.aten.chunk.default,
     ],
@@ -1203,14 +1190,22 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
     [torch.ops.aten.pixel_shuffle.default], QnnConstants.OpDepthToSpace.op_name
 )
 class PixelShuffle(GeneralOpDef):
-    pass
+    @staticmethod
+    def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
+        annotate_in_out_obs_sharing_op(node, quantization_config)
+        if not _is_annotated([node]):
+            annotate_single_in_share_out(node, quantization_config)
 
 
 @register_annotator(
     [torch.ops.aten.pixel_unshuffle.default], QnnConstants.OpSpaceToDepth.op_name
 )
 class PixelUnshuffle(GeneralOpDef):
-    pass
+    @staticmethod
+    def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
+        annotate_in_out_obs_sharing_op(node, quantization_config)
+        if not _is_annotated([node]):
+            annotate_single_in_share_out(node, quantization_config)
 
 
 @register_annotator(
diff --git a/backends/qualcomm/quantizer/annotators/lpai_rules.py b/backends/qualcomm/quantizer/annotators/lpai_rules.py
index 60cebfcc5c0..68191c3ad12 100644
--- a/backends/qualcomm/quantizer/annotators/lpai_rules.py
+++ b/backends/qualcomm/quantizer/annotators/lpai_rules.py
@@ -11,10 +11,6 @@
 import executorch.backends.qualcomm.builders.qnn_constants as QnnConstants
 import torch
-
-from executorch.backends.qualcomm.quantizer.observers.concat_observer import (
-    ConcatObserver,
-)
 from executorch.backends.qualcomm.quantizer.qconfig import (
     get_16a16w_qnn_ptq_config,
     get_16a4w_qnn_qat_config,
@@ -180,35 +176,23 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
         if _is_annotated([node]) or not _is_float_tensor(node):
             return
 
-        input_qspec_map, input_nodes = {}, node.args[0]
-        for input in input_nodes:
-            input_qspec = input.meta.get(Q_ANNOTATION_KEY, None)
-            qspec = getattr(input_qspec, "output_qspec", None)
-            # keep shared qspec here for propagation the data range
-            # without introducing extra requantizations
-            if isinstance(qspec, SharedQuantizationSpec):
-                input_qspec_map[input] = SharedQuantizationSpec(input)
-            else:
-                input_qspec_map[input] = quantization_config.input_activation
+        input_nodes = node.args[0]
+        first_input_node = input_nodes[0]
+        assert isinstance(first_input_node, Node)
+
+        # the cat output shares its qparams with the first input; the other
+        # inputs keep their own observers and are requantized later if needed
+        input_qspec_map = {
+            first_input_node: quantization_config.input_activation,
+        }
+        for input_node in input_nodes[1:]:
+            assert isinstance(input_node, Node)
+            if input_node not in input_qspec_map:
+                input_qspec_map[input_node] = quantization_config.input_activation
 
-        output_qspec = QuantizationSpec(
-            dtype=quantization_config.output_activation.dtype,
-            qscheme=quantization_config.output_activation.qscheme,
-            quant_max=quantization_config.output_activation.quant_max,
-            quant_min=quantization_config.output_activation.quant_min,
-            observer_or_fake_quant_ctr=ConcatObserver.with_args(
-                # we need to know the concat node in order to hack all the input observers' data range
-                # since deep copy of fake tensor (node.meta["val"]) is inhibited
-                # we could only ship grap & node name and perform postprocess inside observer currently
-                **{
-                    "node_name": node.name,
-                    "graph": node.graph,
-                }
-            ),
-        )
         node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
             input_qspec_map=input_qspec_map,
-            output_qspec=output_qspec,
+            output_qspec=SharedQuantizationSpec((first_input_node, node)),
             _annotated=True,
         )
 
@@ -223,6 +207,7 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
 
 @register_annotator(
     [
         torch.ops.aten.split_with_sizes.default,
+        torch.ops.aten.split_with_sizes_copy.default,
         torch.ops.aten.split.Tensor,
         torch.ops.aten.chunk.default,
     ],
@@ -705,14 +690,22 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
     [torch.ops.aten.pixel_shuffle.default], QnnConstants.OpDepthToSpace.op_name
 )
 class PixelShuffle(GeneralOpDef):
-    pass
+    @staticmethod
+    def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
+        annotate_in_out_obs_sharing_op(node, quantization_config)
+        if not _is_annotated([node]):
+            annotate_single_in_share_out(node, quantization_config)
 
 
 @register_annotator(
     [torch.ops.aten.pixel_unshuffle.default], QnnConstants.OpSpaceToDepth.op_name
 )
 class PixelUnshuffle(GeneralOpDef):
-    pass
+    @staticmethod
+    def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
+        annotate_in_out_obs_sharing_op(node, quantization_config)
+        if not _is_annotated([node]):
+            annotate_single_in_share_out(node, quantization_config)
 
 
 @register_annotator(
diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index b0120dd2848..6edb1d01d38 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -371,6 +371,14 @@ def forward(self, x, y):
         return torch.cat((x, y, self.const_tensor), axis=2)
 
 
+class Cat6(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, y):
+        return torch.cat((x, y), axis=2)
+
+
 class CausalMask(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/backends/qualcomm/tests/test_passes.py b/backends/qualcomm/tests/test_passes.py
index 1f007628e61..ee6bcf9d8e5 100644
--- a/backends/qualcomm/tests/test_passes.py
+++ b/backends/qualcomm/tests/test_passes.py
@@ -7,10 +7,12 @@
     ConvertMhaToSha,
     FoldQDQ,
     InsertIOQDQ,
+    InsertRequantize,
     InsertReshapeForReduceOps,
     RemoveRedundancy,
 )
 from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype
+from executorch.backends.qualcomm.quantizer.rules import Q_ANNOTATION_KEY
 from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
 from executorch.backends.qualcomm.tests.models import TopKandIndex
 from executorch.backends.qualcomm.utils.utils import (
@@ -22,19 +24,25 @@
 from executorch.exir.debug_handle_utils import DEBUG_HANDLE_KEY
 from executorch.exir.dialects._ops import ops as exir_ops
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torchao.quantization.pt2e.quantizer import SharedQuantizationSpec
 
 
 class TestPasses(unittest.TestCase):
-    def _build_quantized_graph(self):
+    def _build_quantized_graph(self, module=None, sample_input=None):
         """Build a quantized graph through AnnotateQuantAttrs + FoldQDQ."""
 
-        class AddModule(torch.nn.Module):
-            def forward(self, x):
-                return x + 1
+        if module is None:
+
+            class AddModule(torch.nn.Module):
+                def forward(self, x):
+                    return x + 1
+
+            module = AddModule()
 
-        module = AddModule().eval()
-        sample_input = (torch.randn(1, 4),)
+        if sample_input is None:
+            sample_input = (torch.randn(1, 4),)
 
+        module = module.eval()
         exported = torch.export.export(module, sample_input, strict=True).module()
         quantizer = QnnQuantizer()
         quantizer.set_default_quant_config(quant_dtype=QuantDtype.use_8a8w)
@@ -102,6 +110,60 @@ def test_insert_io_qdq_no_revisit(self):
         # one quantize (input) and one dequantize (output) = +2 nodes.
         self.assertEqual(node_count_after, node_count_before + 2)
 
+    def test_insert_requantize_for_mismatched_cat_inputs(self):
+        class CatRequiresRequant(torch.nn.Module):
+            def forward(self, x):
+                first = torch.clamp(x, -0.1, 0.1)
+                second = x * 10.0
+                return torch.cat((first, second), dim=1)
+
+        sample_input = (torch.linspace(-1.0, 1.0, 16).reshape(1, 1, 4, 4),)
+        gm, _ = self._build_quantized_graph(CatRequiresRequant(), sample_input)
+        gm = InsertRequantize()(gm).graph_module
+
+        cat_node = next(
+            n for n in gm.graph.nodes if n.target == exir_ops.edge.aten.cat.default
+        )
+        cat_inputs = cat_node.args[0]
+        to_copy_target = exir_ops.edge.aten._to_copy.default
+
+        self.assertNotEqual(cat_inputs[0].target, to_copy_target)
+        self.assertEqual(cat_inputs[1].target, to_copy_target)
+
+    def test_cat_annotation_only_shares_output_with_first_input(self):
+        class CatModule(torch.nn.Module):
+            def forward(self, x, y):
+                return torch.cat((x, y), dim=1)
+
+        sample_input = (
+            torch.randn(1, 1, 4, 4),
+            torch.randn(1, 1, 4, 4),
+        )
+        exported = torch.export.export(
+            CatModule().eval(), sample_input, strict=True
+        ).module()
+        quantizer = QnnQuantizer()
+        quantizer.set_default_quant_config(quant_dtype=QuantDtype.use_8a8w)
+        prepared = prepare_pt2e(exported, quantizer)
+
+        cat_node = next(
+            n for n in prepared.graph.nodes if n.target == torch.ops.aten.cat.default
+        )
+        second_input_node = cat_node.args[0][1]
+        # the annotation is recorded on the original producer; unwrap the
+        # observer node prepare_pt2e may have inserted in front of cat
+        if second_input_node not in cat_node.meta[Q_ANNOTATION_KEY].input_qspec_map:
+            second_input_node = second_input_node.args[0]
+
+        self.assertIsInstance(
+            cat_node.meta[Q_ANNOTATION_KEY].output_qspec,
+            SharedQuantizationSpec,
+        )
+        self.assertNotIsInstance(
+            cat_node.meta[Q_ANNOTATION_KEY].input_qspec_map[second_input_node],
+            SharedQuantizationSpec,
+        )
+
     def test_insert_reshape_for_argmax(self):
         class ArgmaxModule(torch.nn.Module):
             def forward(self, x):
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 3e236952933..fa7d6eeb49c 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -8,6 +8,7 @@
 import itertools
 import json
 import logging
+import operator
 import subprocess
 import sys
 import tempfile
@@ -33,11 +34,14 @@
     make_quantizer,
     setup_common_args_and_variables,
 )
+from executorch.backends.qualcomm.quantizer.rules import Q_ANNOTATION_KEY
 from executorch.backends.qualcomm.serialization.qc_schema import (
     QnnExecuTorchBackendType,
     QnnExecuTorchHtpPerformanceMode,
     QnnExecuTorchLpaiTargetEnv,
 )
+
+from executorch.backends.qualcomm.tests.models import Cat2
 from executorch.backends.qualcomm.tests.utils import (
     convert_pt2e,
     generate_context_binary,
@@ -72,7 +76,6 @@
     to_edge_transform_and_lower_to_qnn,
     update_spill_fill_size,
 )
-
 from executorch.backends.qualcomm.tests.models import *  # noqa: F403
 
 import os
@@ -97,6 +100,7 @@
 from executorch.examples.models.wav2letter import Wav2LetterModel
 from executorch.exir import to_edge
 from executorch.exir.backend.backend_api import disable_validation
+from torchao.quantization.pt2e.quantizer import SharedQuantizationSpec
 
 
 class TestQNNFloatingPointOperator(TestQNN):
@@ -1688,12 +1692,16 @@ def test_qnn_backend_permute(self):
 
     def test_qnn_backend_pixel_shuffle(self):
         module = PixelShuffle(2)  # noqa: F405
-        sample_input = (torch.ones([2, 4, 3, 3]),)
+        sample_input = (
+            torch.arange(2 * 4 * 3 * 3, dtype=torch.float32).reshape(2, 4, 3, 3),
+        )
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_pixel_unshuffle(self):
         module = PixelUnshuffle(2)  # noqa: F405
-        sample_input = (torch.ones([2, 2, 6, 6]),)
+        sample_input = (
+            torch.arange(2 * 2 * 6 * 6, dtype=torch.float32).reshape(2, 2, 6, 6),
+        )
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_pow_tensor_scalar(self):
@@ -2799,6 +2807,15 @@ def test_qnn_backend_cat(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_cat_fixed_input(self):
+        module = Cat6()  # noqa: F405
+        sample_input = (
+            torch.tensor([[[[-10.0, 2.0], [3.0, 4.0]]]]),
+            torch.tensor([[[[1.0, 3.0], [8.0, 10.0]]]]),
+        )
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_cdist(self):
         module = CDist()  # noqa: F405
         sample_input = (
@@ -4229,16 +4246,210 @@ def test_qnn_backend_permute(self):
 
     def test_qnn_backend_pixel_shuffle(self):
         module = PixelShuffle(2)  # noqa: F405
-        sample_input = (torch.ones([2, 4, 3, 3]),)
+        sample_input = (
+            torch.arange(2 * 4 * 3 * 3, dtype=torch.float32).reshape(2, 4, 3, 3),
+        )
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_pixel_unshuffle(self):
         module = PixelUnshuffle(2)  # noqa: F405
-        sample_input = (torch.ones([2, 2, 6, 6]),)
+        sample_input = (
+            torch.arange(2 * 2 * 6 * 6, dtype=torch.float32).reshape(2, 2, 6, 6),
+        )
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def _prepare_module_for_qparam_assertions(self, module, sample_input):
+        backend = get_backend_type(self.backend)
+        quantizer = make_quantizer(
+            quant_dtype=QuantDtype.use_8a8w,
+            custom_annotations=(),
+            per_channel_conv=True,
+            per_channel_linear=False,
+            per_channel_embedding=False,
+            backend=backend,
+            soc_model=self.soc_model,
+        )
+        return prepare_pt2e(
+            torch.export.export(module, sample_input, strict=True).module(),
+            quantizer,
+        )
+
+    def _assert_prepared_nodes_share_qparams(
+        self, module, sample_input, target_tokens
+    ) -> list[torch.fx.Node]:
+        prepared = self._prepare_module_for_qparam_assertions(module, sample_input)
+        matching_nodes = [
+            node
+            for node in prepared.graph.nodes
+            if node.op == "call_function"
+            and any(target_token in str(node.target) for target_token in target_tokens)
+        ]
+
+        self.assertGreater(
+            len(matching_nodes),
+            0,
+            f"Failed to find node matching any of {target_tokens}",
+        )
+        for node in matching_nodes:
+            self.assertIsInstance(
+                node.meta[Q_ANNOTATION_KEY].output_qspec,
+                SharedQuantizationSpec,
+            )
+
+        return matching_nodes
+
+    def test_qnn_backend_pixel_shuffle_unshuffle_share_qparams(self):
+        test_cases = [
+            (
+                "pixel_shuffle",
+                PixelShuffle(2),  # noqa: F405
+                (torch.arange(2 * 4 * 3 * 3, dtype=torch.float32).reshape(2, 4, 3, 3),),
+                torch.ops.aten.pixel_shuffle.default,
+            ),
+            (
+                "pixel_unshuffle",
+                PixelUnshuffle(2),  # noqa: F405
+                (torch.arange(2 * 2 * 6 * 6, dtype=torch.float32).reshape(2, 2, 6, 6),),
+                torch.ops.aten.pixel_unshuffle.default,
+            ),
+        ]
+
+        for name, module, sample_input, target in test_cases:
+            with self.subTest(name=name):
+                prepared = self._prepare_module_for_qparam_assertions(
+                    module, sample_input
+                )
+                for node in prepared.graph.nodes:
+                    if node.op == "call_function" and node.target == target:
+                        self.assertIsInstance(
+                            node.meta[Q_ANNOTATION_KEY].output_qspec,
+                            SharedQuantizationSpec,
+                        )
+                        break
+                else:
+                    self.fail(f"Failed to find {target} in prepared graph")
+
+    def test_qnn_backend_value_preserving_ops_share_qparams(self):
+        test_cases = [
+            (
+                "channel_shuffle",
+                ChannelShuffle(2),  # noqa: F405
+                (torch.randn(1, 4, 3, 3),),
+                ("aten.channel_shuffle",),
+            ),
+            (
+                "permute",
+                Permute([0, 2, 3, 1]),  # noqa: F405
+                (torch.randn(2, 3, 4, 5),),
+                ("aten.permute",),
+            ),
+            (
+                "pixel_shuffle",
+                PixelShuffle(2),  # noqa: F405
+                (torch.arange(2 * 4 * 3 * 3, dtype=torch.float32).reshape(2, 4, 3, 3),),
+                ("aten.pixel_shuffle",),
+            ),
+            (
+                "pixel_unshuffle",
+                PixelUnshuffle(2),  # noqa: F405
+                (torch.arange(2 * 2 * 6 * 6, dtype=torch.float32).reshape(2, 2, 6, 6),),
+                ("aten.pixel_unshuffle",),
+            ),
+            (
+                "repeat",
+                Repeat(),  # noqa: F405
+                (torch.randn(2, 2, 2, 2),),
+                ("aten.repeat",),
+            ),
+            (
+                "expand_as",
+                ExpandAs(),  # noqa: F405
+                (torch.randn(3, 4),),
+                ("aten.expand",),
+            ),
+            (
+                "reshape",
+                Reshape(),  # noqa: F405
+                (torch.randn(3, 4),),
+                ("aten.reshape", "aten.view"),
+            ),
+        ]
+
+        for name, module, sample_input, target_tokens in test_cases:
+            with self.subTest(name=name):
+                self._assert_prepared_nodes_share_qparams(
+                    module, sample_input, target_tokens
+                )
+
+    def test_qnn_backend_split_with_sizes_copy_share_qparams(self):
+        class SplitWithSizesCopy(torch.nn.Module):
+            def forward(self, x):
+                out = torch.ops.aten.split_with_sizes_copy.default(x, [2, 2], 1)
+                return out[0] + out[1]
+
+        backend = get_backend_type(self.backend)
+        sample_input = (
+            torch.arange(2 * 4 * 3 * 3, dtype=torch.float32).reshape(2, 4, 3, 3),
+        )
+        quantizer = make_quantizer(
+            quant_dtype=QuantDtype.use_8a8w,
+            custom_annotations=(),
+            per_channel_conv=True,
+            per_channel_linear=False,
+            per_channel_embedding=False,
+            backend=backend,
+            soc_model=self.soc_model,
+        )
+        prepared = prepare_pt2e(
+            torch.export.export(
+                SplitWithSizesCopy(), sample_input, strict=True
+            ).module(),
+            quantizer,
+        )
+
+        getitem_count = 0
+        for node in prepared.graph.nodes:
+            if (
+                node.op == "call_function"
+                and node.target == operator.getitem
+                and node.args[0].target == torch.ops.aten.split_with_sizes_copy.default
+            ):
+                self.assertIsInstance(
+                    node.meta[Q_ANNOTATION_KEY].output_qspec,
+                    SharedQuantizationSpec,
+                )
+                getitem_count += 1
+
+        self.assertGreater(getitem_count, 0)
+
+    def test_qnn_backend_cat_output_shares_with_first_input(self):
+        sample_input = (
+            torch.arange(2 * 3 * 4 * 5, dtype=torch.float32).reshape(2, 3, 4, 5),
+            torch.arange(2 * 3 * 4 * 5, dtype=torch.float32).reshape(2, 3, 4, 5),
+        )
+        prepared = self._prepare_module_for_qparam_assertions(Cat2(), sample_input)
+
+        for node in prepared.graph.nodes:
+            if node.op == "call_function" and node.target == torch.ops.aten.cat.default:
+                self.assertIsInstance(
+                    node.meta[Q_ANNOTATION_KEY].output_qspec,
+                    SharedQuantizationSpec,
+                )
+                second_input_node = node.args[0][1]
+                # the annotation lives on the original producer; unwrap the
+                # observer node prepare_pt2e may have inserted in front of cat
+                if second_input_node not in node.meta[Q_ANNOTATION_KEY].input_qspec_map:
+                    second_input_node = second_input_node.args[0]
+                self.assertNotIsInstance(
+                    node.meta[Q_ANNOTATION_KEY].input_qspec_map[second_input_node],
+                    SharedQuantizationSpec,
+                )
+                break
+        else:
+            self.fail("Failed to find aten.cat.default in prepared graph")
+
     def test_qnn_backend_pow_tensor_scalar(self):
         test_comb = [
             {
diff --git a/examples/qualcomm/executor_runner/qnn_executor_runner.cpp b/examples/qualcomm/executor_runner/qnn_executor_runner.cpp
index 5b531fb27c7..7bf2c6aac71 100644
--- a/examples/qualcomm/executor_runner/qnn_executor_runner.cpp
+++ b/examples/qualcomm/executor_runner/qnn_executor_runner.cpp
@@ -59,6 +59,11 @@ DEFINE_string(
     "etdump.etdp",
     "If etdump generation is enabled an etdump will be written out to this path");
 
+DEFINE_bool(
+    enable_etdump,
+    true,
+    "Enable ETDump event tracing. Disable for cleaner latency benchmarking.");
+
 DEFINE_bool(
     dump_intermediate_outputs,
     false,
@@ -385,8 +390,13 @@ int main(int argc, char** argv) {
   // be used by a single thread at at time, but it can be reused.
   //
   ETDumpGen etdump_gen;
+  // With etdump disabled (and no intermediate-output dumping), pass a null
+  // event tracer so the run is free of profiling overhead.
+  auto* event_tracer = (FLAGS_enable_etdump || FLAGS_dump_intermediate_outputs)
+      ? &etdump_gen
+      : nullptr;
   Result<Method> method =
-      program->load_method(method_name, &memory_manager, &etdump_gen);
+      program->load_method(method_name, &memory_manager, event_tracer);
   ET_CHECK_MSG(
       method.ok(),
       "Loading of method %s failed with status 0x%" PRIx32,
@@ -694,7 +704,7 @@ int main(int argc, char** argv) {
   // Dump the etdump data containing profiling/debugging data to the specified
   // file.
   ETDumpResult result = etdump_gen.get_etdump_data();
-  if (result.buf != nullptr && result.size > 0) {
+  if (FLAGS_enable_etdump && result.buf != nullptr && result.size > 0) {
     ET_LOG(
         Info,
         "Write etdump to %s, Size = %zu",
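
Note on the requant decision in annotate_quant_attrs.py: below is a minimal, standalone sketch of the check that _is_requant_needed performs, using plain dicts and string keys in place of the node metadata and the real QCOM_* constants (the names here are illustrative stand-ins, not the backend's API):

    from typing import Any, Dict

    # String stand-ins for the QCOM_* metadata keys used by AnnotateQuantAttrs.
    SCALE, ZERO_POINT, QUANT_MIN, QUANT_MAX, DTYPE = (
        "scale", "zero_point", "quant_min", "quant_max", "dtype",
    )

    def is_requant_needed(
        src: Dict[str, Any], dst: Dict[str, Any], skip_advanced_requant: bool
    ) -> bool:
        # dtype-only mode: a requantize is inserted only on a dtype mismatch
        if skip_advanced_requant:
            return src[DTYPE] != dst[DTYPE]
        # full mode: any differing quant attribute forces a requantize
        return any(
            src[k] != dst[k] for k in (SCALE, ZERO_POINT, QUANT_MIN, QUANT_MAX, DTYPE)
        )

    # Example: a cat input whose scale differs from the cat output's scale is
    # requantized in full mode but left untouched in dtype-only mode.
    src = {SCALE: 0.1, ZERO_POINT: 0, QUANT_MIN: 0, QUANT_MAX: 255, DTYPE: "uint8"}
    dst = {**src, SCALE: 0.2}
    assert is_requant_needed(src, dst, skip_advanced_requant=False)
    assert not is_requant_needed(src, dst, skip_advanced_requant=True)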