Commit 1bc2c4e

NXP backend: Linear + BatchNorm QAT fusing (pytorch#16623)
### Summary

Adds two passes that insert and remove a simulated BatchNorm fusion for QAT training, similarly to how `_fuse_conv_bn_qat` adds a simulated Conv + BatchNorm fusion in the `prepare_qat_pt2e` function from TorchAO.

### Test plan

Added integration tests that cover the newly added implementation.
1 parent d130d50 commit 1bc2c4e

10 files changed: 957 additions & 4 deletions
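
Taken together, the new passes slot into the PT2E QAT flow as follows. Below is a condensed sketch of the `calibrate_and_quantize` change from `backends/nxp/quantizer/utils.py` (shown in full further down). `model`, `quantizer`, and `calibration_inputs` are caller-supplied, the `move_exported_model_to_eval` step is omitted, and the torchao import path is assumed from the modules referenced elsewhere in this diff:

```python
from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_linear_pass import (
    FuseBatchNormWithLinearPass,
)
from executorch.backends.nxp.aten_passes.simulated_linear_bn_fusion_passes import (
    AddSimulatedLinearBatchNormFusionQATPass,
    RemoveSimulatedLinearBatchNormFusionQATPass,
)
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_qat_pt2e


def qat_quantize(model, quantizer, calibration_inputs):
    m = prepare_qat_pt2e(model, quantizer)
    # Rewrite Linear + BatchNorm pairs into the simulated-fusion subgraph so the
    # fake-quantizers observe the weights that will exist after the real fusion.
    m = AddSimulatedLinearBatchNormFusionQATPass()(m).graph_module

    for data in calibration_inputs:  # QAT training / calibration
        m(*data)

    # Strip the simulated-fusion artifacts, then fold BatchNorm into Linear for real.
    m = RemoveSimulatedLinearBatchNormFusionQATPass()(m).graph_module
    m = FuseBatchNormWithLinearPass()(m).graph_module
    return convert_pt2e(m)
```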

backends/nxp/aten_passes/fuse_batch_norm_with_linear_pass.py

Lines changed: 27 additions & 4 deletions
@@ -2,14 +2,28 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import Optional

 import torch
 from torch.export.unflatten import _assign_attr, _AttrKind
 from torch.fx import GraphModule, Node
 from torch.fx.passes.infra.pass_base import PassBase, PassResult
 from torch.nn.parameter import Parameter
 from torch.nn.utils import fuse_linear_bn_weights
+from torchao.quantization.pt2e.prepare import _is_activation_post_process_node
+
+
+def _unwrap_if_fq(node: Node, named_modules: dict):
+    target_node = node
+
+    if _is_activation_post_process_node(node, named_modules):
+        if len(node.args) >= 1:
+            target_node = node.args[0]
+        else:
+            raise ValueError(
+                f"FakeQuantize node '{node}' should have at least one argument, but has {len(node.args)}."
+            )
+
+    return target_node


 class FuseBatchNormWithLinearPass(PassBase):
@@ -53,7 +67,7 @@ def _get_tensor_constant_from_node(self, graph_module, node) -> Parameter | None
             attr_itr = getattr(attr_itr, atom)
         return attr_itr

-    def call(self, graph_module: GraphModule) -> Optional[PassResult]:
+    def call(self, graph_module: GraphModule) -> PassResult | None:
         def _is_batch_norm(node_: Node) -> bool:
             return (
                 node_.op == "call_function"
@@ -76,6 +90,8 @@ def _is_linear(node_: Node):
                 graph_module, made_changes
             )  # No batch norm nodes in the model.

+        named_modules = dict(graph_module.named_modules(remove_duplicate=False))
+
         for node in graph_module.graph.nodes:
             if not _is_batch_norm(node):
                 continue  # Not BatchNorm.
@@ -86,11 +102,18 @@ def _is_linear(node_: Node):
                 continue  # Something other than a Linear node comes before the BatchNorm.

             linear_node = bn_node.args[0]
-            linear_weight_node = linear_node.args[1]
-            linear_bias_node = (
+            linear_weight_node_or_fq = linear_node.args[1]
+            linear_bias_node_or_fq = (
                 linear_node.args[2] if len(linear_node.args) > 2 else None
             )

+            linear_weight_node = _unwrap_if_fq(
+                linear_weight_node_or_fq, named_modules=named_modules
+            )
+            linear_bias_node = _unwrap_if_fq(
+                linear_bias_node_or_fq, named_modules=named_modules
+            )
+
             linear_w = self._get_tensor_constant_from_node(
                 graph_module, linear_weight_node
             )
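
The `_unwrap_if_fq` helper lets the pass look through a FakeQuantize (activation post-process) node to the weight or bias node it wraps, so the fusion now also works on graphs produced by `prepare_qat_pt2e`. A minimal way to exercise the pass on its own (a sketch: it assumes `torch.export` keeps `aten.linear` and a batch-norm variant the pass matches un-decomposed, and `LinearBN` is a made-up toy module):

```python
import torch
from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_linear_pass import (
    FuseBatchNormWithLinearPass,
)
from torch.export import export


class LinearBN(torch.nn.Module):  # hypothetical toy model
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)
        self.bn = torch.nn.BatchNorm1d(8)

    def forward(self, x):
        return self.bn(self.linear(x))


gm = export(LinearBN().eval(), (torch.randn(2, 8),)).module()
# PassBase instances are callable and return a PassResult, as used elsewhere in this diff.
fused = FuseBatchNormWithLinearPass()(gm).graph_module
```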
backends/nxp/aten_passes/simulated_linear_bn_fusion_passes/__init__.py

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+from executorch.backends.nxp.aten_passes.simulated_linear_bn_fusion_passes.add_simulated_linear_bn_fusion_qat_pass import (
+    AddSimulatedLinearBatchNormFusionQATPass,
+)
+from executorch.backends.nxp.aten_passes.simulated_linear_bn_fusion_passes.remove_simulated_linear_bn_fusion_qat_pass import (
+    RemoveSimulatedLinearBatchNormFusionQATPass,
+)
+
+__all__ = [
+    "AddSimulatedLinearBatchNormFusionQATPass",
+    "RemoveSimulatedLinearBatchNormFusionQATPass",
+]

backends/nxp/aten_passes/simulated_linear_bn_fusion_passes/add_simulated_linear_bn_fusion_qat_pass.py

Lines changed: 384 additions & 0 deletions
Large diffs are not rendered by default.
backends/nxp/aten_passes/simulated_linear_bn_fusion_passes/remove_simulated_linear_bn_fusion_qat_pass.py

Lines changed: 190 additions & 0 deletions
@@ -0,0 +1,190 @@
+# Copyright 2026 NXP
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from functools import partial
+
+import torch
+from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_linear_pass import (
+    _unwrap_if_fq,
+)
+from executorch.backends.nxp.aten_passes.simulated_linear_bn_fusion_passes.add_simulated_linear_bn_fusion_qat_pass import (
+    _get_compute_scale_factor_pattern,
+    _get_linear_weight_preprocess_pattern,
+)
+from executorch.backends.nxp.backend.graph_utils import is_batch_norm, is_op_node
+from torch.fx import GraphModule, Node
+from torch.fx.passes.infra.pass_base import PassBase, PassResult
+from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher
+from torchao.quantization.pt2e.qat_utils import _get_aten_graph_module_for_pattern
+
+_is_add = partial(is_op_node, target_op=torch.ops.aten.add.Tensor)
+_is_div = partial(is_op_node, target_op=torch.ops.aten.div.Tensor)
+_is_linear = partial(is_op_node, target_op=torch.ops.aten.linear.default)
+_is_reshape = partial(is_op_node, target_op=torch.ops.aten.reshape)
+_is_zeros_like = partial(is_op_node, target_op=torch.ops.aten.zeros_like)
+
+
+def _is_denorm_pattern(node: Node) -> bool:
+    if not _is_div(node):
+        return False
+
+    if not hasattr(node, "users"):
+        return False
+
+    div_users = node.users.keys()
+    if len(list(div_users)) < 1:
+        return False
+
+    if any(is_batch_norm(user) for user in div_users):
+        return True
+
+    return False
+
+
+def _remove_pattern_from_graph(graph_module: GraphModule, pattern: GraphModule):
+    matcher = SubgraphMatcher(
+        pattern.graph,
+        match_output=False,
+        match_placeholder=False,
+        remove_overlapping_matches=True,
+        ignore_literals=True,
+    )
+    matches: list[InternalMatch] = matcher.match(graph_module.graph, node_name_match="")
+
+    for match in matches:
+        last_pattern_node = match.anchors[0]
+        last_matched_subgraph_node = match.nodes_map[last_pattern_node]
+        weight = match.placeholder_nodes[0]
+
+        last_matched_subgraph_node.replace_all_uses_with(weight)
+
+        for node in match.nodes_map.values():
+            if node not in match.placeholder_nodes:
+                graph_module.graph.erase_node(node)
+
+
+def _remove_late_bias_pattern(graph_module: GraphModule, bias_node: Node):
+    linear_b_users = list(bias_node.users.keys())
+
+    if len(linear_b_users) != 2:
+        return
+
+    if _is_zeros_like(linear_b_users[0]):
+        zeros_node, maybe_reshape_node = linear_b_users
+    elif _is_zeros_like(linear_b_users[1]):
+        maybe_reshape_node, zeros_node = linear_b_users
+    else:
+        return
+
+    if _is_reshape(maybe_reshape_node):
+        reshape_node = maybe_reshape_node
+        reshape_users = list(reshape_node.users.keys())
+
+        if len(reshape_users) != 1:
+            return
+
+        add_node = reshape_users[0]
+    else:
+        # Handles no reshape node when bias is scalar
+        reshape_node = None
+        add_node = maybe_reshape_node
+
+    if not _is_add(add_node):
+        return
+
+    # Remove zeroed linear bias
+    zeros_node.replace_all_uses_with(bias_node)
+    graph_module.graph.erase_node(zeros_node)
+
+    # Remove late bias addition
+    add_node.replace_all_uses_with(add_node.args[0])
+    graph_module.graph.erase_node(add_node)
+
+    if reshape_node:
+        graph_module.graph.erase_node(reshape_node)
+
+
+def _remove_denorm_and_late_bias(graph_module: GraphModule):
+    named_modules = dict(graph_module.named_modules(remove_duplicate=False))
+
+    for node in graph_module.graph.nodes:
+        if not _is_linear(node):
+            continue
+
+        linear_node = node
+
+        if len(linear_node.args) < 2:
+            continue
+
+        maybe_linear_bias = linear_node.args[2] if len(linear_node.args) > 2 else None
+        linear_bias_fq_or_zeros = _unwrap_if_fq(
+            maybe_linear_bias, named_modules=named_modules
+        )
+        has_late_bias = _is_zeros_like(linear_bias_fq_or_zeros)
+
+        if has_late_bias:
+            _remove_late_bias_pattern(
+                graph_module, bias_node=linear_bias_fq_or_zeros.args[0]
+            )
+
+        for user_node in linear_node.users:
+            if _is_denorm_pattern(user_node):
+                if any(is_batch_norm(user) for user in user_node.users.keys()):
+                    user_node.replace_all_uses_with(node)
+                    graph_module.graph.erase_node(user_node)
+                break
+
+
+class RemoveSimulatedLinearBatchNormFusionQATPass(PassBase):
+    """
+    In order for QAT to work correctly with fused linear + batch norm operators,
+    simulated linear + batch norm fusion should be added using AddSimulatedLinearBatchNormFusionQATPass.
+
+    After the QAT training, before inserting QDQ nodes, nodes added by the simulated fusion should be removed.
+    This pass removes all artifacts created by AddSimulatedLinearBatchNormFusionQATPass and reverts
+    the graph back to the layout before the simulated fusion was applied.
+    See `add_simulated_linear_bn_fusion_qat_pass.py` for more details.
+    """
+
+    def call(self, graph_module: GraphModule) -> PassResult | None:
+        """
+        Given a graph of decomposed aten ops, removes nodes corresponding to linear + batch norm fusion.
+        """
+        is_cuda = False
+
+        graph_module.graph.eliminate_dead_code()
+        graph_module.recompile()
+
+        _scale_compute_example_inputs = (
+            torch.randn(1),
+            torch.randn(1),
+        )
+        _preprocess_example_inputs = (
+            torch.randn(1, 1),
+            torch.randn(1),
+        )
+
+        scale_pattern = _get_compute_scale_factor_pattern()
+        scale_match_pattern = _get_aten_graph_module_for_pattern(
+            pattern=scale_pattern,
+            example_inputs=_scale_compute_example_inputs,
+            is_cuda=is_cuda,
+        )
+
+        weight_preprocess_pattern = _get_linear_weight_preprocess_pattern()
+        weight_preprocess_pattern = _get_aten_graph_module_for_pattern(
+            pattern=weight_preprocess_pattern,
+            example_inputs=_preprocess_example_inputs,
+            is_cuda=is_cuda,
+        )
+
+        _remove_pattern_from_graph(graph_module, pattern=scale_match_pattern)
+        _remove_pattern_from_graph(graph_module, pattern=weight_preprocess_pattern)
+        _remove_denorm_and_late_bias(graph_module)
+
+        graph_module.graph.eliminate_dead_code()
+        graph_module.recompile()
+
+        return PassResult(graph_module, True)
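
To make the removed patterns concrete, here is a rough eager-mode sketch of the simulated-fusion forward that this pass unwinds, inferred from the pattern helpers above: the scale-factor computation, the preprocessed weight, the zeroed bias with a late bias add, and the "denorm" division feeding the batch norm. Shapes are illustrative; the authoritative construction lives in `add_simulated_linear_bn_fusion_qat_pass.py`, whose diff is not rendered above:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 3)
weight = torch.randn(4, 3)
bias = torch.randn(4)
bn = torch.nn.BatchNorm1d(4)

# "Compute scale factor" pattern: derived from the BN parameters and running variance.
scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)

# "Linear weight preprocess" pattern: fold the scale into the weight so the
# weight fake-quantizer observes the post-fusion values during training.
scaled_weight = weight * scale.reshape(-1, 1)

out = F.linear(x, scaled_weight, torch.zeros_like(bias))  # zeroed linear bias
out = out / scale                # "denorm" div that ends up feeding the batch norm
out = out + bias.reshape(1, -1)  # late bias add (the zeros_like / reshape / add pattern)
out = bn(out)                    # batch norm is still present during QAT
```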
backends/nxp/backend/graph_utils.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+# Copyright 2026 NXP
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+from torch.fx import Node
+
+batch_norm_target_ops = [
+    # Aten dialect variants
+    torch.ops.aten.batch_norm.default,
+    torch.ops.aten.native_batch_norm.default,
+    torch.ops.aten._native_batch_norm_legit_no_training.default,
+    # Edge dialect variants
+    exir_ops.edge.aten.batch_norm.default,
+    exir_ops.edge.aten.native_batch_norm.default,
+    exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
+]
+
+
+def is_op_node(node: Node, target_op) -> bool:
+    if isinstance(target_op, list):
+        target_ops = target_op
+    else:
+        target_ops = [target_op]
+
+    return (
+        node is not None
+        and hasattr(node, "op")
+        and node.op == "call_function"
+        and hasattr(node, "target")
+        and node.target in target_ops
+    )
+
+
+def is_batch_norm(node: Node) -> bool:
+    return is_op_node(node, batch_norm_target_ops)
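
A quick sanity check for `is_batch_norm` against an exported graph (a sketch; it assumes `torch.export` lowers eval-mode batch norm to one of the aten variants listed above):

```python
import torch
from executorch.backends.nxp.backend.graph_utils import is_batch_norm
from torch.export import export


class BN(torch.nn.Module):  # hypothetical toy model
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm1d(4)

    def forward(self, x):
        return self.bn(x)


gm = export(BN().eval(), (torch.randn(2, 4),)).module()
assert any(is_batch_norm(n) for n in gm.graph.nodes)
```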

backends/nxp/quantizer/patterns.py

Lines changed: 10 additions & 0 deletions
@@ -711,6 +711,16 @@ def get_anchors(
         output = []
         activation.meta["quantization_annotation"].input_qspec_map = {}

+        # In order for QAT to be numerically correct, there should be no quantization between
+        # the linear node and the batch norm node.
+        if self.is_qat:
+            linear_users = linear_node.users
+            possibly_bn = (
+                list(linear_users.keys())[0] if len(linear_users) == 1 else None
+            )
+            if possibly_bn and _is_batch_norm(possibly_bn):
+                output = []
+
         return PartitionAnchors(
             inputs=[(linear_node, NodeArgsIdx(0))],
             weights=[(linear_node, NodeArgsIdx(1))],

backends/nxp/quantizer/utils.py

Lines changed: 13 additions & 0 deletions
@@ -13,6 +13,13 @@
 from typing import Any, Dict, List, Tuple, Type

 import torch
+from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_linear_pass import (
+    FuseBatchNormWithLinearPass,
+)
+from executorch.backends.nxp.aten_passes.simulated_linear_bn_fusion_passes import (
+    AddSimulatedLinearBatchNormFusionQATPass,
+    RemoveSimulatedLinearBatchNormFusionQATPass,
+)
 from torch import fx
 from torch._ops import OpOverload
 from torch.export import ExportedProgram
@@ -184,12 +191,18 @@ def calibrate_and_quantize(

     if is_qat:
         m = prepare_qat_pt2e(model, quantizer)
+        m = AddSimulatedLinearBatchNormFusionQATPass()(m).graph_module
         m = move_exported_model_to_eval(m)
     else:
         m = prepare_pt2e(model, quantizer)

     for data in calibration_inputs:
         m(*data)
+
+    if is_qat:
+        m = RemoveSimulatedLinearBatchNormFusionQATPass()(m).graph_module
+        m = FuseBatchNormWithLinearPass()(m).graph_module
+
     m = convert_pt2e(m)

     return m

backends/nxp/tests/ir/converter/node_converter/test_mm_converter.py

Lines changed: 1 addition & 0 deletions
@@ -60,6 +60,7 @@ def test_mm_conversion(self, _, use_qat: bool):
             exported_program,
             input_data,
             tfl_model=tflite_flatbuffers_model,
+            atol=1.0,
         )

     @parameterized.expand([("QAT", True), ("PTQ", False)])
