Skip to content

Commit ef6ceae

Browse files
authored
NXP backend: added Squeeze support (pytorch#16540)
### Summary Adds support for the "Squeeze" operator. ### Test plan Tests can be run manually using `pytest -c /dev/null backends/nxp/tests/`. cc @robert-kalmar @MartinPavella
1 parent fc0f06b commit ef6ceae

13 files changed

Lines changed: 323 additions & 232 deletions

backends/nxp/aten_passes/convert_unsqueeze_to_view.py

Lines changed: 0 additions & 79 deletions
This file was deleted.

backends/nxp/aten_passes/decompose_split_to_slices_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def call(self, graph_module: GraphModule) -> Optional[PassResult]:
187187
self._replace_split_with_slices(input_node, split_node, starts, ends, dim)
188188
made_changes = True
189189

190-
self.graph_module.recompile()
191190
self.graph_module.graph.eliminate_dead_code()
191+
self.graph_module.recompile()
192192

193193
return PassResult(self.graph_module, made_changes)

backends/nxp/aten_passes/neutron_aten_pass_manager.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@
77

88
import torch
99

10-
from executorch.backends.nxp.aten_passes.convert_unsqueeze_to_view import (
11-
ConvertUnsqueezeToViewPass,
12-
)
1310
from executorch.backends.nxp.aten_passes.decompose_split_to_slices_pass import (
1411
DecomposeSplitToSlicesPass,
1512
)
@@ -50,7 +47,6 @@ def _get_default_passes(neutron_target_spec, qat_mode: bool = False) -> list[Pas
5047
RemoveNodesWithKnownOutputs(),
5148
FuseLinearAndAddPass(),
5249
MoveActivationBeforeConcat(neutron_target_spec),
53-
ConvertUnsqueezeToViewPass(),
5450
]
5551

5652
if not qat_mode:
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Copyright 2026 NXP
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
import torch
8+
9+
from executorch.backends.nxp.edge_passes.neutron_edge_pass import NeutronEdgePass
10+
from executorch.exir.dialects._ops import ops as exir_ops
11+
from torch._subclasses import FakeTensor, FakeTensorMode
12+
from torch.fx import GraphModule, Node
13+
from torch.fx.passes.infra.pass_base import PassResult
14+
15+
16+
class ConvertReshapingNodesToViewPass(NeutronEdgePass):
    """Replace reshaping-only edge-dialect nodes with `aten.view_copy.default`.

    The pass rewrites calls to the following edge-dialect targets:
      - `aten.squeeze_copy.default`, `aten.squeeze_copy.dims`, `aten.squeeze_copy.dim`
      - `aten.unsqueeze_copy.default`

          x                                              x
          │                                              │
      aten.[un]squeeze_copy(x, dim)     ──►      aten.view_copy.default(x, S)
          │                                              │
          ▼                                              ▼
         out                                            out

    `S` is the output shape already recorded in the replaced node's
    `meta["val"]`, so the rewrite is shape-preserving by construction.
    """

    # Set by `run()`; the helper methods operate on this module's graph.
    graph_module: GraphModule

    @staticmethod
    def _is_squeeze(node_: Node) -> bool:
        """Return True if `node_` calls any `squeeze_copy` overload (dim/dims/default)."""
        return node_.op == "call_function" and node_.target in (
            exir_ops.edge.aten.squeeze_copy.dim,
            exir_ops.edge.aten.squeeze_copy.dims,
            exir_ops.edge.aten.squeeze_copy.default,
        )

    @staticmethod
    def _is_unsqueeze(node_: Node) -> bool:
        """Return True if `node_` calls `aten.unsqueeze_copy.default`."""
        return (
            node_.op == "call_function"
            and node_.target == exir_ops.edge.aten.unsqueeze_copy.default
        )

    def _create_view_copy_node(self, *view_args) -> Node:
        """Insert an `aten.view_copy.default` node with `view_args` and valid meta.

        `view_args[0]` must be an fx `Node` carrying a fake tensor in
        `meta["val"]`; the remaining args are forwarded to `view_copy`
        (here: the target shape). The new node's `meta["val"]` is computed by
        running the op on a fake tensor so downstream passes see the correct
        output shape and dtype.
        """
        view_target = exir_ops.edge.aten.view_copy.default
        view_node = self.graph_module.graph.call_function(view_target, view_args)

        view_node.meta["source_fn_stack"] = [(view_node.name, view_target)]

        x_val = view_args[0].meta["val"]
        with FakeTensorMode() as mode:
            fake_input = FakeTensor.from_tensor(
                torch.empty(x_val.shape, dtype=x_val.dtype), mode
            )
            # Let the op itself infer the output shape instead of duplicating
            # the view/squeeze shape logic here.
            output_shape = view_target(fake_input, *view_args[1:]).shape
            view_node.meta["val"] = FakeTensor.from_tensor(
                torch.empty(output_shape, dtype=x_val.dtype), mode
            )

        return view_node

    def run(self, graph_module: GraphModule) -> PassResult:
        """Replace the first matching squeeze/unsqueeze node and return.

        At most one node is rewritten per invocation; the parent class calls
        the pass again, which avoids traversing a graph mutated mid-iteration.
        """
        self.graph_module = graph_module

        for node in list(graph_module.graph.nodes):
            if not (self._is_squeeze(node) or self._is_unsqueeze(node)):
                continue

            input_node = node.all_input_nodes[0]
            # The replaced node's recorded output shape becomes the
            # `view_copy` target shape.
            target_shape = node.meta["val"].shape

            with self.graph_module.graph.inserting_after(node):
                view_copy_node = self._create_view_copy_node(input_node, target_shape)

            node.replace_all_uses_with(view_copy_node)
            self.graph_module.graph.erase_node(node)

            self.graph_module.graph.eliminate_dead_code()
            self.graph_module.recompile()

            # Return immediately to avoid traversing a modified graph.
            # The parent class will call this pass again.
            return PassResult(graph_module, True)

        # The graph was not modified.
        return PassResult(graph_module, False)

backends/nxp/edge_passes/move_auxiliary_operator_into_separate_qdq_cluster_pass.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 NXP
1+
# Copyright 2025-2026 NXP
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -105,6 +105,9 @@ class MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
105105
ViewCopy,
106106
],
107107
ViewCopy: [Clone, CloneDimOrder],
108+
Conv: [
109+
ViewCopy, # For 1D conv.
110+
],
108111
}
109112

110113
def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
@@ -200,6 +203,7 @@ class MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
200203
Relu,
201204
Sigmoid,
202205
Tanh,
206+
ViewCopy, # For 1D conv.
203207
],
204208
ViewCopy: [Clone, CloneDimOrder],
205209
}

backends/nxp/edge_passes/neutron_edge_pass_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
from executorch.backends.nxp.edge_passes.convert_reshaping_nodes_to_view import (
7+
ConvertReshapingNodesToViewPass,
8+
)
69
from executorch.backends.nxp.edge_passes.move_auxiliary_operator_into_separate_qdq_cluster_pass import (
710
MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass,
811
MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass,
@@ -21,6 +24,7 @@ def __init__(self, passes: list[NeutronEdgePass] = None):
2124
MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(),
2225
MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass(),
2326
RemoveUselessAsStridedCopyNodes(),
27+
ConvertReshapingNodesToViewPass(),
2428
]
2529

2630
super().__init__(

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,14 @@
4848
SigmoidPattern,
4949
SliceTensorPattern,
5050
SoftMaxPattern,
51+
SqueezeDimPattern,
52+
SqueezeDimsPattern,
53+
SqueezePattern,
5154
SubTensorPattern,
5255
TanhInPlacePattern,
5356
TanhPattern,
5457
TransposeIntPattern,
58+
UnsqueezePattern,
5559
UpsampleBilinear2DPattern,
5660
UpsampleNearest2DPattern,
5761
ViewPattern,
@@ -281,10 +285,14 @@ def __init__(self, neutron_target_spec: NeutronTargetSpec, is_qat: bool = False)
281285
OpQuantizer(SigmoidPattern(is_qat=is_qat), static_qconfig),
282286
OpQuantizer(SliceTensorPattern(is_qat=is_qat), static_qconfig),
283287
OpQuantizer(SoftMaxPattern(is_qat=is_qat), static_qconfig),
288+
OpQuantizer(SqueezeDimPattern(is_qat=is_qat), static_qconfig),
289+
OpQuantizer(SqueezeDimsPattern(is_qat=is_qat), static_qconfig),
290+
OpQuantizer(SqueezePattern(is_qat=is_qat), static_qconfig),
284291
OpQuantizer(SubTensorPattern(is_qat=is_qat), static_qconfig),
285292
OpQuantizer(TanhPattern(is_qat=is_qat), static_qconfig),
286293
OpQuantizer(TanhInPlacePattern(is_qat=is_qat), static_qconfig),
287294
OpQuantizer(TransposeIntPattern(is_qat=is_qat), static_qconfig),
295+
OpQuantizer(UnsqueezePattern(is_qat=is_qat), static_qconfig),
288296
OpQuantizer(UpsampleBilinear2DPattern(is_qat=is_qat), static_qconfig),
289297
OpQuantizer(UpsampleNearest2DPattern(is_qat=is_qat), static_qconfig),
290298
OpQuantizer(ViewPattern(is_qat=is_qat), static_qconfig),

backends/nxp/quantizer/patterns.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -972,6 +972,33 @@ def get_anchors(
972972
)
973973

974974

975+
class SqueezePattern(SharedSpecPattern):
    """Shared-spec quantization pattern for the `aten.squeeze.default` operator."""

    def partition_types(self):
        return [torch.ops.aten.squeeze.default]
982+
983+
984+
class SqueezeDimPattern(SharedSpecPattern):
    """Shared-spec quantization pattern for the `aten.squeeze.dim` operator."""

    def partition_types(self):
        return [torch.ops.aten.squeeze.dim]
991+
992+
993+
class SqueezeDimsPattern(SharedSpecPattern):
    """Shared-spec quantization pattern for the `aten.squeeze.dims` operator."""

    def partition_types(self):
        return [torch.ops.aten.squeeze.dims]
1000+
1001+
9751002
class TanhPattern(QuantizationPattern):
9761003
"""
9771004
Quantizer for Tanh operator.
@@ -1008,6 +1035,13 @@ def get_anchors(
10081035
)
10091036

10101037

1038+
class UnsqueezePattern(SharedSpecPattern):
    """Shared-spec quantization pattern for the `aten.unsqueeze.default` operator."""

    def partition_types(self):
        return [torch.ops.aten.unsqueeze.default]
1043+
1044+
10111045
class UpsampleBilinear2DPattern(SharedSpecPattern):
10121046
"""
10131047
Quantizer for `aten.upsample_bilinear2d.vec` operator.

backends/nxp/tests/models.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -788,3 +788,15 @@ def forward(self, x, divisor):
788788
# partition 2
789789
x = self.prelu(x)
790790
return x
791+
792+
793+
class SqueezeAddModel(torch.nn.Module):
    """Test model: adds its two inputs and squeezes the sum.

    When `dim` is given at construction, only that dimension is squeezed;
    otherwise all size-1 dimensions are removed.
    """

    def __init__(self, dim=None):
        super().__init__()
        self.dim = dim

    def forward(self, x, y):
        total = x + y
        if self.dim is not None:
            return torch.squeeze(total, self.dim)
        return torch.squeeze(total)

backends/nxp/tests/test_batch_norm_fusion.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 NXP
1+
# Copyright 2025-2026 NXP
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -172,9 +172,7 @@ def test_batch_norm_conv_fusing__full_pipeline__1d(bias: bool):
172172
).exported_program()
173173
nodes = list(edge_program.graph.nodes)
174174

175-
assert (
176-
len(nodes) == 17
177-
) # 1D Conv currently isn't delegated, because it doesn't get quantized.
175+
assert len(nodes) == 13
178176
assert not any(
179177
node.op == "call_function" and "batch_norm" in node.target.__name__
180178
for node in nodes

0 commit comments

Comments
 (0)