Skip to content

Commit f5e8123

Browse files
Arm backend: Format docs in backends/arm/test/ops (pytorch#17616)
Signed-off-by: Sebastian Larsson <sebastian.larsson@arm.com>
1 parent 218e36a commit f5e8123

48 files changed

Lines changed: 300 additions & 171 deletions

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

backends/arm/test/ops/test_adaptive_avg_pool2d.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -173,7 +173,9 @@ def test_adaptive_avg_pool2d_u55_INT(test_module):
173173
@common.parametrize("test_module", u55_test_modules)
174174
@common.XfailIfNoCorstone300
175175
def test_adaptive_avg_pool2d_u55_INT_a16w8(test_module):
176-
"""Test adaptive_avg_pool2d with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
176+
"""Test adaptive_avg_pool2d with 16A8W quantization on U55 (16-bit
177+
activations, 8-bit weights)
178+
"""
177179
model, input_tensor = test_module()
178180
pipeline = EthosU55PipelineINT[input_t](
179181
model,
@@ -202,7 +204,9 @@ def test_adaptive_avg_pool2d_u85_INT(test_module):
202204
@common.parametrize("test_module", test_modules)
203205
@common.XfailIfNoCorstone320
204206
def test_adaptive_avg_pool2d_u85_INT_a16w8(test_module):
205-
"""Test adaptive_avg_pool2d with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
207+
"""Test adaptive_avg_pool2d with 16A8W quantization on U85 (16-bit
208+
activations, 8-bit weights)
209+
"""
206210
model, input_tensor = test_module()
207211
pipeline = EthosU85PipelineINT[input_t](
208212
model,

backends/arm/test/ops/test_add.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,9 @@ def test_add_tensor_vgf_quant(test_data: input_t1):
280280

281281
@common.parametrize("test_data", Add.test_data)
282282
def test_add_tensor_tosa_INT_16a8w(test_data: input_t1):
283-
"""Test add operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
283+
"""Test add operation with 16A8W quantization (16-bit activations, 8-bit
284+
weights)
285+
"""
284286
per_channel_quantization = False
285287

286288
pipeline = TosaPipelineINT[input_t1](
@@ -302,7 +304,9 @@ def test_add_tensor_tosa_INT_16a8w(test_data: input_t1):
302304
@common.parametrize("test_data", Add.test_data)
303305
@common.XfailIfNoCorstone300
304306
def test_add_tensor_u55_INT_16a8w(test_data: input_t1):
305-
"""Test add operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
307+
"""Test add operation with 16A8W quantization on U55 (16-bit activations,
308+
8-bit weights)
309+
"""
306310
per_channel_quantization = False
307311

308312
pipeline = EthosU55PipelineINT[input_t1](
@@ -323,7 +327,9 @@ def test_add_tensor_u55_INT_16a8w(test_data: input_t1):
323327
@common.parametrize("test_data", Add.test_data)
324328
@common.XfailIfNoCorstone320
325329
def test_add_tensor_u85_INT_16a8w(test_data: input_t1):
326-
"""Test add operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
330+
"""Test add operation with 16A8W quantization on U85 (16-bit activations,
331+
8-bit weights)
332+
"""
327333
per_channel_quantization = False
328334

329335
pipeline = EthosU85PipelineINT[input_t1](

backends/arm/test/ops/test_addmm.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,9 @@ def test_addmm_vgf_quant(test_data: input_t1):
190190

191191
@common.parametrize("test_data", test_data_suite)
192192
def test_addmm_16a8w_tosa_INT(test_data: input_t1):
193-
"""Test addmm (FC layer) operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
193+
"""Test addmm (FC layer) operation with 16A8W quantization (16-bit
194+
activations, 8-bit weights)
195+
"""
194196
per_channel_quantization = False
195197

196198
pipeline = TosaPipelineINT[input_t1](
@@ -214,7 +216,9 @@ def test_addmm_16a8w_tosa_INT(test_data: input_t1):
214216
reason="Vela compilation fails with 'Invalid arguments' for int16 addmm operations"
215217
)
216218
def test_addmm_16a8w_u55_INT(test_data: input_t1):
217-
"""Test addmm (FC layer) operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
219+
"""Test addmm (FC layer) operation with 16A8W quantization on U55 (16-bit
220+
activations, 8-bit weights)
221+
"""
218222
per_channel_quantization = False
219223

220224
pipeline = EthosU55PipelineINT[input_t1](
@@ -234,7 +238,9 @@ def test_addmm_16a8w_u55_INT(test_data: input_t1):
234238
@common.parametrize("test_data", test_data_suite)
235239
@common.XfailIfNoCorstone320
236240
def test_addmm_16a8w_u85_INT(test_data: input_t1):
237-
"""Test addmm (FC layer) operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
241+
"""Test addmm (FC layer) operation with 16A8W quantization on U85 (16-bit
242+
activations, 8-bit weights)
243+
"""
238244
per_channel_quantization = False
239245

240246
pipeline = EthosU85PipelineINT[input_t1](

backends/arm/test/ops/test_alias_copy.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@
1919

2020

2121
class AliasCopy(torch.nn.Module):
22-
"""
23-
Tests proper handling of alias_copy when used directly.
22+
"""Tests proper handling of alias_copy when used directly.
23+
24+
alias_copy can also appear from PyTorch/ExecuTorch optimizations such as
25+
`x.transpose(0, 0)`. This is optimized to an alias_copy but not before dq/q
26+
operators are added.
2427
25-
alias_copy can also appear from PyTorch/ExecuTorch optimizations
26-
such as `x.transpose(0, 0)`. This is optimized to an alias_copy but
27-
not before dq/q operators are added.
2828
"""
2929

3030
aten_op = "torch.ops.aten.alias_copy.default"

backends/arm/test/ops/test_amax.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,9 @@ def test_amax_tosa_INT_a16w8(test_data: Amax.input_t):
280280
@common.parametrize("test_data", Amax.test_data)
281281
@common.XfailIfNoCorstone320
282282
def test_amax_u85_INT_a16w8(test_data: Amax.input_t):
283-
"""Test amax with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
283+
"""Test amax with 16A8W quantization on U85 (16-bit activations, 8-bit
284+
weights)
285+
"""
284286
data, dim, keep_dims = test_data()
285287
module = Amax(dim, keep_dims)
286288
pipeline = EthosU85PipelineINT[Max.input_t](

backends/arm/test/ops/test_amin.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,9 @@ def test_amin_tosa_INT_a16w8(test_data: Amin.input_t):
276276
@common.parametrize("test_data", Amin.test_data)
277277
@common.XfailIfNoCorstone320
278278
def test_amin_u85_INT_a16w8(test_data: Min.input_t):
279-
"""Test amin with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
279+
"""Test amin with 16A8W quantization on U85 (16-bit activations, 8-bit
280+
weights)
281+
"""
280282
data, dim, keep_dims = test_data()
281283
pipeline = EthosU85PipelineINT[Amin.input_t](
282284
Amin(dim, keep_dims),

backends/arm/test/ops/test_avg_pool2d.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,11 @@ def forward(self, *args, **kwargs):
3535

3636

3737
class BecomesMeanInToEdge(torch.nn.Module):
38-
"""This averagepool will be converted to mean when lowering to edge. This causes the decompose_meandim pass to not
39-
trigger until the backend pipeline, which requires extra care.
38+
"""This averagepool will be converted to mean when lowering to edge.
39+
40+
This causes the decompose_meandim pass to not trigger until the backend
41+
pipeline, which requires extra care.
42+
4043
"""
4144

4245
def forward(self, x: torch.Tensor):
@@ -207,7 +210,9 @@ def test_avg_pool2d_u55_INT(test_module):
207210
@common.parametrize("test_module", test_modules)
208211
@common.XfailIfNoCorstone300
209212
def test_avg_pool2d_16a8w_u55_INT(test_module):
210-
"""Test avg_pool2d with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
213+
"""Test avg_pool2d with 16A8W quantization on U55 (16-bit activations, 8-bit
214+
weights)
215+
"""
211216
model, input_tensor = test_module()
212217
pipeline = EthosU55PipelineINT[input_t](
213218
model,
@@ -238,7 +243,9 @@ def test_avg_pool2d_u85_INT(test_module):
238243
@common.parametrize("test_module", test_modules)
239244
@common.XfailIfNoCorstone320
240245
def test_avg_pool2d_16a8w_u85_INT(test_module):
241-
"""Test avg_pool2d with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
246+
"""Test avg_pool2d with 16A8W quantization on U85 (16-bit activations, 8-bit
247+
weights)
248+
"""
242249
model, input_tensor = test_module()
243250
pipeline = EthosU85PipelineINT[input_t](
244251
model,

backends/arm/test/ops/test_batch_norm.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3-
# Copyright 2024-2025 Arm Limited and/or its affiliates.
3+
# Copyright 2024-2026 Arm Limited and/or its affiliates.
44
#
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
@@ -266,9 +266,7 @@ def test_native_batch_norm_legit_no_training_vgf_quant_conv(test_data: Tuple):
266266

267267

268268
class BatchNorm2dNoStats(torch.nn.Module):
269-
"""
270-
Decomposes into _native_batch_norm_legit.no_stats
271-
"""
269+
"""Decomposes into _native_batch_norm_legit.no_stats."""
272270

273271
aten_ops = ["torch.ops.aten.batch_norm.default"]
274272

backends/arm/test/ops/test_cat.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,9 @@ def test_cat_vgf_quant(test_data: Tuple):
198198

199199
@common.parametrize("test_data", Cat.test_parameters)
200200
def test_cat_16a8w_tosa_INT(test_data: Tuple):
201-
"""Test cat operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
201+
"""Test cat operation with 16A8W quantization (16-bit activations, 8-bit
202+
weights)
203+
"""
202204
per_channel_quantization = False
203205

204206
pipeline = TosaPipelineINT[input_t1](
@@ -219,7 +221,9 @@ def test_cat_16a8w_tosa_INT(test_data: Tuple):
219221
@common.parametrize("test_data", Cat.test_parameters)
220222
@common.XfailIfNoCorstone300
221223
def test_cat_16a8w_u55_INT(test_data: Tuple):
222-
"""Test cat operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
224+
"""Test cat operation with 16A8W quantization on U55 (16-bit activations,
225+
8-bit weights)
226+
"""
223227
per_channel_quantization = False
224228

225229
pipeline = EthosU55PipelineINT[input_t1](
@@ -240,7 +244,9 @@ def test_cat_16a8w_u55_INT(test_data: Tuple):
240244
@common.parametrize("test_data", Cat.test_parameters)
241245
@common.XfailIfNoCorstone320
242246
def test_cat_16a8w_u85_INT(test_data: Tuple):
243-
"""Test cat operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
247+
"""Test cat operation with 16A8W quantization on U85 (16-bit activations,
248+
8-bit weights)
249+
"""
244250
per_channel_quantization = False
245251

246252
pipeline = EthosU85PipelineINT[input_t1](

backends/arm/test/ops/test_clamp.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,9 @@ def test_clamp_u55_INT(test_data):
175175
@common.parametrize("test_data", test_data_suite)
176176
@common.XfailIfNoCorstone300
177177
def test_clamp_u55_INT_16a8w(test_data):
178-
"""Test clamp operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
178+
"""Test clamp operation with 16A8W quantization on U55 (16-bit activations,
179+
8-bit weights)
180+
"""
179181
input_tensor, min_val, max_val = test_data()
180182
model = Clamp(min_val, max_val)
181183
pipeline = EthosU55PipelineINT[input_t](
@@ -208,7 +210,9 @@ def test_clamp_u85_INT(test_data):
208210
@common.parametrize("test_data", test_data_suite)
209211
@common.XfailIfNoCorstone320
210212
def test_clamp_u85_INT_16a8w(test_data):
211-
"""Test clamp operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
213+
"""Test clamp operation with 16A8W quantization on U85 (16-bit activations,
214+
8-bit weights)
215+
"""
212216
input_tensor, min_val, max_val = test_data()
213217
model = Clamp(min_val, max_val)
214218
pipeline = EthosU85PipelineINT[input_t](
@@ -453,7 +457,9 @@ def test_clamp_u55_INT_tensor(test_data):
453457
@common.parametrize("test_data", test_data_suite_tensor_INT32)
454458
@common.XfailIfNoCorstone300
455459
def test_clamp_u55_INT_16a8w_tensor(test_data):
456-
"""Test clamp operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
460+
"""Test clamp operation with 16A8W quantization on U55 (16-bit activations,
461+
8-bit weights)
462+
"""
457463
input_tensor, min_val, max_val = test_data()
458464
model = Clamp(min_val, max_val)
459465
pipeline = EthosU55PipelineINT[input_t](
@@ -486,7 +492,9 @@ def test_clamp_u85_INT_tensor(test_data):
486492
@common.parametrize("test_data", test_data_suite_tensor_INT32)
487493
@common.XfailIfNoCorstone320
488494
def test_clamp_u85_INT_16a8w_tensor(test_data):
489-
"""Test clamp operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
495+
"""Test clamp operation with 16A8W quantization on U85 (16-bit activations,
496+
8-bit weights)
497+
"""
490498
input_tensor, min_val, max_val = test_data()
491499
model = Clamp(min_val, max_val)
492500
pipeline = EthosU85PipelineINT[input_t](

0 commit comments

Comments (0)