Commit 218e36a

Arm backend: Format Arm passes (group 7) (pytorch#17615)
Apply docformatter to the final set of passes.

Signed-off-by: Sebastian Larsson <sebastian.larsson@arm.com>
1 parent 1bc2c4e commit 218e36a
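
For reference, docformatter normalizes docstrings to PEP 257 layout: summary on the opening-quote line, a blank line, then a wrapped body. A minimal before/after sketch of the transformation seen throughout this diff (the invocation uses docformatter's standard --in-place flag; any additional flags the Arm lint job passes are an assumption, and the function below is a simplified stand-in):

# e.g. run: docformatter --in-place backends/arm/_passes/arm_pass.py

# Before: a one-line docstring that overflows the line-length limit.
def _is_quantized_meta(meta):
    """Return True when meta indicates fully quantized inputs and outputs."""

# After: summary kept at the quotes, overflow wrapped onto a body line.
def _is_quantized_meta(meta):
    """Return True when meta indicates fully quantized inputs and
    outputs.
    """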

10 files changed: 90 additions & 73 deletions

.lintrunner.toml
Lines changed: 0 additions & 1 deletion

@@ -510,7 +510,6 @@ include_patterns = ['backends/arm/**/*.py']
 exclude_patterns = [
     'third-party/**',
     '**/third-party/**',
-    'backends/arm/_passes/**',
     'backends/arm/test/**',
 ]
 command = [

backends/arm/_passes/arm_pass.py
Lines changed: 9 additions & 4 deletions

@@ -38,7 +38,9 @@ def allowed_to_transform(self, meta: NodeMetadata | dict[str, Any]) -> bool:
         return not disallow_tfa

     def _is_quantized_meta(self, meta: NodeMetadata | dict[str, Any]) -> bool:
-        """Return True when meta indicates fully quantized inputs and outputs."""
+        """Return True when meta indicates fully quantized inputs and
+        outputs.
+        """
         if isinstance(meta, NodeMetadata):
             meta_dict = meta.data
         else:

@@ -107,9 +109,12 @@ def call_submodule(
     def call_shape_operator(
         self, op, args: tuple, kwargs: dict, meta: NodeMetadata, update: bool
     ) -> ProxyValue:
-        """
-        Call operator for shape-producing operators. This function is responsible for marking the output of the operator
-        with the TosaSpecialDtype of SHAPE, so that later passes can identify it as a shape-producing operator and handle it accordingly.
+        """Call operator for shape-producing operators.
+
+        This function is responsible for marking the output of the operator with
+        the TosaSpecialDtype of SHAPE, so that later passes can identify it as a
+        shape-producing operator and handle it accordingly.
+
         """
         # Copy meta and set TosaSpecialDtype to SHAPE
         if not isinstance(meta, NodeMetadata):

backends/arm/_passes/constant_folding_pass.py
Lines changed: 6 additions & 4 deletions

@@ -12,7 +12,11 @@


 class ConstantFoldingPass(ArmPass):
-    """Fold constant subgraphs using torch's export constant folding pass. To be used before to_edge transform."""
+    """Fold constant subgraphs using torch's export constant folding pass.
+
+    To be used before to_edge transform.
+
+    """

     _passes_required_after: Set[Type[ExportPass]] = set()

@@ -31,9 +35,7 @@ class ConstantFoldingPass(ArmPass):
     def _ensure_param_attr(
         self, graph_module: torch.fx.GraphModule, node: torch.fx.Node | None
     ) -> bool:
-        """
-        Replaces tensor attributes with parameter attributes.
-        """
+        """Replaces tensor attributes with parameter attributes."""
         if node is None or node.op != "get_attr":
             return False
         target = node.target

backends/arm/_passes/control_flow_const_inline.py
Lines changed: 7 additions & 4 deletions

@@ -15,11 +15,14 @@


 class ControlFlowConstInlinePass(ArmPass):
-    """
-    When we lift out each control flow body as its own GraphModule, any scalar constants that were captured in Python become module attributes. FX represents those as get_attr nodes in the
-    submodule graph. These become getattr nodes submodule graph.
+    """When we lift out each control flow body as its own GraphModule, any
+    scalar constants that were captured in Python become module attributes. FX
+    represents those as get_attr nodes in the submodule graph. These become
+    getattr nodes submodule graph.
+
+    This pass ensures that Scalar tensors in control flow operation are
+    converted from getattr operators to expected call_function full ops.

-    This pass ensures that Scalar tensors in control flow operation are converted from getattr operators to expected call_function full ops.
     """

     _passes_required_after: Set[Type[ExportPass]] = set()

backends/arm/_passes/convert_full_like_to_full_pass.py
Lines changed: 9 additions & 9 deletions

@@ -20,15 +20,15 @@ class ConvertFullLikeToFullPass(ArmPass):
     As per the full_like PyTorch documentation, `torch.full_like(input,
     fill_value)` is equivalent to:

-    ```
-    torch.full(
-        input.size(),
-        fill_value,
-        dtype=input.dtype,
-        layout=input.layout,
-        device=input.device,
-    )
-    ```
+    ::
+
+        torch.full(
+            input.size(),
+            fill_value,
+            dtype=input.dtype,
+            layout=input.layout,
+            device=input.device,
+        )

     Skip layout and device since it's not relevant for our backend.
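
The equivalence cited in this docstring is easy to sanity-check in eager PyTorch. A minimal sketch (arbitrary shape and fill value; layout and device omitted, mirroring what the pass skips):

import torch

x = torch.ones(2, 3, dtype=torch.int32)
a = torch.full_like(x, 7)
b = torch.full(x.size(), 7, dtype=x.dtype)  # layout/device dropped
assert torch.equal(a, b)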

backends/arm/_passes/decompose_strided_slice_copy_pass.py
Lines changed: 3 additions & 2 deletions

@@ -43,8 +43,8 @@ def _fixup_end(end, dim_size):


 class DecomposeStridedSliceCopyPass(ArmPass):
-    """
-    Decompose edge.aten.slice_copy.Tensor with non-unit step into supported ops.
+    """Decompose edge.aten.slice_copy.Tensor with non-unit step into supported
+    ops.

     Given:
         out = slice_copy(x, dim, start, end, step) with step > 1

@@ -57,6 +57,7 @@ class DecomposeStridedSliceCopyPass(ArmPass):
     5) out = view_copy(y3, ...) # collapse the singleton dim

     This implements "take every step-th element" using only unit-step slice + reshape.
+
     """

     _passes_required_after: Set[Type[ExportPass]] = set()
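
For intuition, here is the "unit-step slice + reshape" trick in eager PyTorch. This is an editor's sketch of the idea only, not the pass's actual edge-op rewrite, and it assumes (end - start) divides evenly by step; the pass's numbered steps handle the general case:

import torch

x = torch.arange(12)
start, end, step = 0, 12, 3

ref = x[start:end:step]     # strided slice: tensor([0, 3, 6, 9])

n = (end - start) // step * step
y = x[start:start + n]      # unit-step slice
y = y.reshape(-1, step)     # split out a step-sized trailing dim
y = y[:, 0:1]               # unit-step slice keeps the first column
out = y.reshape(-1)         # collapse the singleton dim

assert torch.equal(ref, out)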

backends/arm/_passes/decompose_tril_pass.py
Lines changed: 6 additions & 3 deletions

@@ -45,9 +45,12 @@ def _get_ops(op):


 class DecomposeTrilPass(ArmPass):
-    """
-    mask_bool = (row + diagonal) >= col (intended AOT-constant)
-    out = where(mask_bool, x, 0) (0 is a scalar tensor, broadcasted)
+    """Tril decomposition.
+
+    Decomposition:
+        mask_bool = (row + diagonal) >= col (intended AOT-constant)
+        out = where(mask_bool, x, 0) (0 is a scalar tensor, broadcasted)
+
     """

     _passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOTPass}
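
The two equations above can be checked directly in eager PyTorch. A minimal sketch (the pass itself emits the corresponding edge ops and leaves the constant mask to ComputeConstantOpsAOTPass):

import torch

x = torch.arange(16.0).reshape(4, 4)
diagonal = 0

row = torch.arange(x.shape[-2]).unsqueeze(-1)  # row indices as a column
col = torch.arange(x.shape[-1]).unsqueeze(0)   # col indices as a row
mask_bool = (row + diagonal) >= col            # the intended AOT constant

out = torch.where(mask_bool, x, torch.tensor(0.0))  # 0 broadcast as scalar
assert torch.equal(out, torch.tril(x, diagonal=diagonal))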

backends/arm/_passes/scalars_to_attribute_pass.py
Lines changed: 3 additions & 4 deletions

@@ -41,8 +41,7 @@ def _convert_scalar_args(
         graph_module: GraphModule,
         n: Node,
     ) -> None:
-        """
-        Convert scalar literal args of targeted_ops in node n of graph_module
+        """Convert scalar literal args of targeted_ops in node n of graph_module
         into attribute get_attr nodes with registered buffers.
         """
         if n.op != "call_function" or n.target not in self.targeted_ops:

@@ -97,8 +96,8 @@ def _convert_scalar_args(
         graph_module.graph.erase_node(n)

     def handle_control_nodes(self, node: Node, graph_module: GraphModule) -> None:
-        """
-        Apply scalar argument conversion on subgraphs of control-flow nodes.
+        """Apply scalar argument conversion on subgraphs of control-flow
+        nodes.
         """
         for _, submodule, _ in get_cond_while_submodules(graph_module):
             for submodule_node in submodule.graph.nodes:

backends/arm/_passes/to_tosa_memory_format_pass.py
Lines changed: 42 additions & 39 deletions

@@ -23,8 +23,8 @@


 def _is_input(node: torch.fx.Node, exported_program: ExportedProgram) -> bool:
-    """
-    Returns True if the node is an input node, i.e. a placeholder or a parameter.
+    """Returns True if the node is an input node, i.e. a placeholder or a
+    parameter.
     """
     return node.op == "placeholder" and not is_param_node(exported_program, node)

@@ -42,12 +42,16 @@ def _is_transpose_conv2d_weight(node: torch.fx.Node) -> bool:


 class ToTosaMemoryFormatPass(ArmPass):
-    """
-    Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order
-    that in most cases will be (0, 2, 3, 1) for nodes with 4D-shapes. The pass also inserts backend.tosa.TRANSPOSE
-    when a transition between 3D and 4D/5D tensors happen.
-    The annotated tosa_dim_order is used to permute the node's shape such that it gives a TOSA-compliant shape.
-    This pass also makes other values aware of spatial dimensions required by future operators by back propogating info as required.
+    """Annotates each node with a tosa_dim_order.
+
+    tosa_dim_order can be seen as a channels-last dim-order that in most cases
+    will be (0, 2, 3, 1) for nodes with 4D-shapes. The pass also inserts
+    backend.tosa.TRANSPOSE when a transition between 3D and 4D/5D tensors
+    happen. The annotated tosa_dim_order is used to permute the node's shape
+    such that it gives a TOSA-compliant shape. This pass also makes other values
+    aware of spatial dimensions required by future operators by back propogating
+    info as required.
+
     """

     _passes_required_after: Set[Type[ExportPass]] = set()

@@ -58,8 +62,7 @@ def __init__(self, exported_program: ExportedProgram, *args, **kwargs) -> None:

     @staticmethod
     def _channels_last_order(rank: int, spatial_rank: int) -> tuple[int, ...]:
-        """
-        Compute the permutation of tensor dimensions corresponding to a
+        """Compute the permutation of tensor dimensions corresponding to a
         "channels_last"-style memory layout for an arbitrary tensor rank.

         In standard PyTorch convention:

@@ -85,6 +88,7 @@ def _channels_last_order(rank: int, spatial_rank: int) -> tuple[int, ...]:
         If `rank <= 2`, the function returns the identity order since there
         are no distinct channel/spatial dimensions.
         In practice only rank 4+ tensors will reach this function as the dim order should be fixed for those.
+
         """
         if rank <= 2:
             return tuple(range(rank))
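
As a standalone illustration of the permutation rule this docstring describes (an editor's sketch, not the pass's implementation; the batch/channel/spatial split below is inferred from the documented (0, 2, 3, 1) rank-4 case):

def channels_last_order(rank: int, spatial_rank: int) -> tuple[int, ...]:
    if rank <= 2:
        return tuple(range(rank))  # identity: no channel/spatial split
    channel = rank - spatial_rank - 1          # assumed channel position
    batch = tuple(range(channel))              # leading batch dims stay put
    spatial = tuple(range(channel + 1, rank))  # spatial dims move forward
    return batch + spatial + (channel,)        # channel goes last

assert channels_last_order(4, 2) == (0, 2, 3, 1)     # NCHW  -> NHWC
assert channels_last_order(5, 3) == (0, 2, 3, 4, 1)  # NCDHW -> NDHWC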

@@ -96,11 +100,11 @@ def _channels_last_order(rank: int, spatial_rank: int) -> tuple[int, ...]:

     @staticmethod
     def _channels_last_inverse_order(rank: int, spatial_rank: int) -> tuple[int, ...]:
-        """
-        Return the inverse permutation of `_channels_last_order`.
+        """Return the inverse permutation of `_channels_last_order`.
+
+        This provides the axis order needed to map a tensor from "channels_last"
+        layout back to its original layout.

-        This provides the axis order needed to map a tensor from
-        "channels_last" layout back to its original layout.
         """
         order = ToTosaMemoryFormatPass._channels_last_order(rank, spatial_rank)
         inverse = [0] * rank
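
For intuition: a permutation's inverse sends each axis back to where it came from, so (0, 2, 3, 1) inverts to (0, 3, 1, 2), mapping NHWC back to NCHW. A sketch matching the `inverse = [0] * rank` construction visible in the context above:

def inverse_order(order: tuple[int, ...]) -> tuple[int, ...]:
    inverse = [0] * len(order)
    for new_pos, old_pos in enumerate(order):
        inverse[old_pos] = new_pos
    return tuple(inverse)

assert inverse_order((0, 2, 3, 1)) == (0, 3, 1, 2)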

@@ -109,15 +113,16 @@ def _channels_last_inverse_order(rank: int, spatial_rank: int) -> tuple[int, ...]:
         return tuple(inverse)

     def _initial_spatial_rank(self, node: torch.fx.Node) -> int:
-        """
-        Infer the initial spatial rank based on the current rank, input node spatial
-        ranks and node target. A spatial dimension includes Height, Width or Depth
-        fields. In most operators this will only ever be Height and Width, but for 3D
-        operators such as conv3d this would contain 3 spatial dims.
+        """Infer the initial spatial rank based on the current rank, input node
+        spatial ranks and node target. A spatial dimension includes Height,
+        Width or Depth fields. In most operators this will only ever be Height
+        and Width, but for 3D operators such as conv3d this would contain 3
+        spatial dims.

         Spatial rank is the max of any input node spatial ranks and the number of
         trailing spatial dims we need to preserve (rank - 2, capped at 3). This
         decides which axes must stay channels-last when inserting transposes.
+
         """
         tensor = get_first_fake_tensor(node).data
         # Start by assuming 2D when dealing with rank4+ to account for the base case

@@ -150,9 +155,9 @@ def _initial_spatial_rank(self, node: torch.fx.Node) -> int:

     @staticmethod
     def memory_format_differs(shape, spatial_rank):
-        """
-        Determine whether a tensor shape would be laid out differently in
-        channels-first ((N)NCHW) versus channels-last ((N)NHWC) memory format.
+        """Determine whether a tensor shape would be laid out differently in
+        channels-first ((N)NCHW) versus channels-last ((N)NHWC) memory
+        format.
         """
         if len(shape) <= 2 or spatial_rank <= 0:
             return False

@@ -168,8 +173,7 @@ def memory_format_differs(shape, spatial_rank):
     def is_channel_reshape(
         input_shape, output_shape, input_spatial_rank, output_spatial_rank
     ):
-        """
-        Check whether a reshape touches the logical channel or consolidated
+        """Check whether a reshape touches the logical channel or consolidated
         batch dimensions, which would invalidate dim-order annotations.
         """

@@ -202,8 +206,7 @@ def get_batch_prod_dim(shape, spatial_rank):

     @staticmethod
     def insert_input_transpose(node, input_node, graph_module):
-        """
-        Ensure an input tensor is converted to channels-last ordering by
+        """Ensure an input tensor is converted to channels-last ordering by
         inserting (or folding) a backend `TRANSPOSE` node.
         """
         if input_node.target == exir_ops.backend.tosa.TRANSPOSE.default:

@@ -240,8 +243,7 @@ def insert_input_transpose(node, input_node, graph_module):

     @staticmethod
     def insert_output_transpose(node, graph_module):
-        """
-        Convert a producer's output to channels-last by appending a backend
+        """Convert a producer's output to channels-last by appending a backend
         `TRANSPOSE` node and rewiring its users.
         """

@@ -280,9 +282,9 @@ def insert_output_transpose(node, graph_module):
     def _insert_view_transpose(
         input_shape, output_shape, node, input_node, graph_module
     ):
-        """
-        Insert the necessary input/output transposes around reshapes that cross
-        the (N)NCHW -> (N)NHWC boundary or that touch channel dimensions.
+        """Insert the necessary input/output transposes around reshapes that
+        cross the (N)NCHW -> (N)NHWC boundary or that touch channel
+        dimensions.
         """
         nchw_to_nhwc = len(input_shape) < 4 and len(output_shape) >= 4
         nhwc_to_nchw = len(input_shape) >= 4 and len(output_shape) < 4

@@ -310,8 +312,10 @@ def _insert_view_transpose(
             ToTosaMemoryFormatPass.insert_output_transpose(node, graph_module)

     def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
-        """
-        Transposes are needed for operators transforming the input to a different rank, as 4D and 5D-tensors are assumed to be in (N)NHWC-format, whereas all other are in (N)NCHW format.
+        """Transposes are needed for operators transforming the input to a
+        different rank, as 4D and 5D-tensors are assumed to be in (N)NHWC-
+        format, whereas all other are in (N)NCHW format.
+
         This is relevant for the following cases:
         - view: <4D -> >=4D
         - view: >=4D -> <4D

@@ -321,6 +325,7 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
         - H == W == 1
         - C == 1
         - 1D/2D tensors
+
         """
         for node in graph_module.graph.nodes:
             # call_function and placeholder allowed due to
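
One way to see why the `C == 1` (and likewise `H == W == 1`) cases need no transpose: with a singleton channel dim, the channels-first and channels-last element orders coincide, so the permutation moves no data. A small eager-mode check (arbitrary shape):

import torch

x = torch.arange(2 * 1 * 4 * 4).reshape(2, 1, 4, 4)  # NCHW with C == 1
nhwc = x.permute(0, 2, 3, 1).contiguous()            # channels-last copy
assert torch.equal(x.flatten(), nhwc.flatten())      # same raw element order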

@@ -383,9 +388,8 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
     def remove_dim_order_kwargs(
         self, graph_module: torch.fx.GraphModule, node: torch.fx.Node
     ):
-        """
-        Drop any user-specified `dim_order` keyword arguments so the pass remains
-        the single source of truth for dim-order annotations.
+        """Drop any user-specified `dim_order` keyword arguments so the pass
+        remains the single source of truth for dim-order annotations.
         """
         if node.op != "call_function":
             return

@@ -439,9 +443,8 @@ def call(self, graph_module: torch.fx.GraphModule):
         return PassResult(graph_module, True)

     def _propagate_spatial_ranks(self, nodes):
-        """
-        Propagate `tosa_spatial_rank` metadata backwards so earlier nodes learn
-        about upcoming spatial requirements from future ops.
+        """Propagate `tosa_spatial_rank` metadata backwards so earlier nodes
+        learn about upcoming spatial requirements from future ops.
         """
         changed = True
         while changed:

backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py
Lines changed: 5 additions & 3 deletions

@@ -1,4 +1,4 @@
-# Copyright 2024-2025 Arm Limited and/or its affiliates.
+# Copyright 2024-2026 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

@@ -15,9 +15,11 @@


 class UnsqueezeScalarPlaceholdersPass(ArmPass):
-    """
-    Placeholders that have node.meta["val"].shape = () cause issues later in the lowering.
+    """Placeholders that have node.meta["val"].shape = () cause issues later in
+    the lowering.
+
     This pass unsqueezes the placeholders to make sure shape is at least (1,).
+
     """

     _passes_required_after: Set[Type[ExportPass]] = set()
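
The shape fix this docstring describes, pictured on a live tensor (a sketch only; the pass operates on placeholder metadata in the exported program, not on runtime tensors):

import torch

t = torch.tensor(3.0)  # rank-0 value: shape == ()
u = t.unsqueeze(0)     # shape becomes (1,)
assert t.shape == torch.Size([]) and u.shape == torch.Size([1])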
