2323
2424
2525def _is_input (node : torch .fx .Node , exported_program : ExportedProgram ) -> bool :
26- """
27- Returns True if the node is an input node, i.e. a placeholder or a parameter.
26+ """Returns True if the node is an input node, i.e. a placeholder or a
27+ parameter.
2828 """
2929 return node .op == "placeholder" and not is_param_node (exported_program , node )
3030
@@ -42,12 +42,16 @@ def _is_transpose_conv2d_weight(node: torch.fx.Node) -> bool:
4242
4343
4444class ToTosaMemoryFormatPass (ArmPass ):
45- """
46- Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order
47- that in most cases will be (0, 2, 3, 1) for nodes with 4D-shapes. The pass also inserts backend.tosa.TRANSPOSE
48- when a transition between 3D and 4D/5D tensors happen.
49- The annotated tosa_dim_order is used to permute the node's shape such that it gives a TOSA-compliant shape.
50- This pass also makes other values aware of spatial dimensions required by future operators by back propogating info as required.
45+ """Annotates each node with a tosa_dim_order.
46+
47+ tosa_dim_order can be seen as a channels-last dim-order that in most cases
48+ will be (0, 2, 3, 1) for nodes with 4D-shapes. The pass also inserts
49+ backend.tosa.TRANSPOSE when a transition between 3D and 4D/5D tensors
50+ happen. The annotated tosa_dim_order is used to permute the node's shape
51+ such that it gives a TOSA-compliant shape. This pass also makes other values
52+ aware of spatial dimensions required by future operators by back propogating
53+ info as required.
54+
5155 """
5256
5357 _passes_required_after : Set [Type [ExportPass ]] = set ()
@@ -58,8 +62,7 @@ def __init__(self, exported_program: ExportedProgram, *args, **kwargs) -> None:
5862
5963 @staticmethod
6064 def _channels_last_order (rank : int , spatial_rank : int ) -> tuple [int , ...]:
61- """
62- Compute the permutation of tensor dimensions corresponding to a
65+ """Compute the permutation of tensor dimensions corresponding to a
6366 "channels_last"-style memory layout for an arbitrary tensor rank.
6467
6568 In standard PyTorch convention:
@@ -85,6 +88,7 @@ def _channels_last_order(rank: int, spatial_rank: int) -> tuple[int, ...]:
8588 If `rank <= 2`, the function returns the identity order since there
8689 are no distinct channel/spatial dimensions.
8790 In practice only rank 4+ tensors will reach this function as the dim order should be fixed for those.
91+
8892 """
8993 if rank <= 2 :
9094 return tuple (range (rank ))
@@ -96,11 +100,11 @@ def _channels_last_order(rank: int, spatial_rank: int) -> tuple[int, ...]:
96100
97101 @staticmethod
98102 def _channels_last_inverse_order (rank : int , spatial_rank : int ) -> tuple [int , ...]:
99- """
100- Return the inverse permutation of `_channels_last_order`.
103+ """Return the inverse permutation of `_channels_last_order`.
104+
105+ This provides the axis order needed to map a tensor from "channels_last"
106+ layout back to its original layout.
101107
102- This provides the axis order needed to map a tensor from
103- "channels_last" layout back to its original layout.
104108 """
105109 order = ToTosaMemoryFormatPass ._channels_last_order (rank , spatial_rank )
106110 inverse = [0 ] * rank
@@ -109,15 +113,16 @@ def _channels_last_inverse_order(rank: int, spatial_rank: int) -> tuple[int, ...
109113 return tuple (inverse )
110114
111115 def _initial_spatial_rank (self , node : torch .fx .Node ) -> int :
112- """
113- Infer the initial spatial rank based on the current rank, input node spatial
114- ranks and node target. A spatial dimension includes Height, Width or Depth
115- fields. In most operators this will only ever be Height and Width, but for 3D
116- operators such as conv3d this would contain 3 spatial dims.
116+ """Infer the initial spatial rank based on the current rank, input node
117+ spatial ranks and node target. A spatial dimension includes Height,
118+ Width or Depth fields. In most operators this will only ever be Height
119+ and Width, but for 3D operators such as conv3d this would contain 3
120+ spatial dims.
117121
118122 Spatial rank is the max of any input node spatial ranks and the number of
119123 trailing spatial dims we need to preserve (rank - 2, capped at 3). This
120124 decides which axes must stay channels-last when inserting transposes.
125+
121126 """
122127 tensor = get_first_fake_tensor (node ).data
123128 # Start by assuming 2D when dealing with rank4+ to account for the base case
@@ -150,9 +155,9 @@ def _initial_spatial_rank(self, node: torch.fx.Node) -> int:
150155
151156 @staticmethod
152157 def memory_format_differs (shape , spatial_rank ):
153- """
154- Determine whether a tensor shape would be laid out differently in
155- channels-first ((N)NCHW) versus channels-last ((N)NHWC) memory format.
158+ """Determine whether a tensor shape would be laid out differently in
159+ channels-first ((N)NCHW) versus channels-last ((N)NHWC) memory
160+ format.
156161 """
157162 if len (shape ) <= 2 or spatial_rank <= 0 :
158163 return False
@@ -168,8 +173,7 @@ def memory_format_differs(shape, spatial_rank):
168173 def is_channel_reshape (
169174 input_shape , output_shape , input_spatial_rank , output_spatial_rank
170175 ):
171- """
172- Check whether a reshape touches the logical channel or consolidated
176+ """Check whether a reshape touches the logical channel or consolidated
173177 batch dimensions, which would invalidate dim-order annotations.
174178 """
175179
@@ -202,8 +206,7 @@ def get_batch_prod_dim(shape, spatial_rank):
202206
203207 @staticmethod
204208 def insert_input_transpose (node , input_node , graph_module ):
205- """
206- Ensure an input tensor is converted to channels-last ordering by
209+ """Ensure an input tensor is converted to channels-last ordering by
207210 inserting (or folding) a backend `TRANSPOSE` node.
208211 """
209212 if input_node .target == exir_ops .backend .tosa .TRANSPOSE .default :
@@ -240,8 +243,7 @@ def insert_input_transpose(node, input_node, graph_module):
240243
241244 @staticmethod
242245 def insert_output_transpose (node , graph_module ):
243- """
244- Convert a producer's output to channels-last by appending a backend
246+ """Convert a producer's output to channels-last by appending a backend
245247 `TRANSPOSE` node and rewiring its users.
246248 """
247249
@@ -280,9 +282,9 @@ def insert_output_transpose(node, graph_module):
280282 def _insert_view_transpose (
281283 input_shape , output_shape , node , input_node , graph_module
282284 ):
283- """
284- Insert the necessary input/output transposes around reshapes that cross
285- the (N)NCHW -> (N)NHWC boundary or that touch channel dimensions.
285+ """Insert the necessary input/output transposes around reshapes that
286+ cross the (N)NCHW -> (N)NHWC boundary or that touch channel
287+ dimensions.
286288 """
287289 nchw_to_nhwc = len (input_shape ) < 4 and len (output_shape ) >= 4
288290 nhwc_to_nchw = len (input_shape ) >= 4 and len (output_shape ) < 4
@@ -310,8 +312,10 @@ def _insert_view_transpose(
310312 ToTosaMemoryFormatPass .insert_output_transpose (node , graph_module )
311313
312314 def insert_tosa_transposes (self , graph_module : torch .fx .GraphModule ):
313- """
314- Transposes are needed for operators transforming the input to a different rank, as 4D and 5D-tensors are assumed to be in (N)NHWC-format, whereas all other are in (N)NCHW format.
315+ """Transposes are needed for operators transforming the input to a
316+ different rank, as 4D and 5D-tensors are assumed to be in (N)NHWC-
317+ format, whereas all other are in (N)NCHW format.
318+
315319 This is relevant for the following cases:
316320 - view: <4D -> >=4D
317321 - view: >=4D -> <4D
@@ -321,6 +325,7 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
321325 - H == W == 1
322326 - C == 1
323327 - 1D/2D tensors
328+
324329 """
325330 for node in graph_module .graph .nodes :
326331 # call_function and placeholder allowed due to
@@ -383,9 +388,8 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
383388 def remove_dim_order_kwargs (
384389 self , graph_module : torch .fx .GraphModule , node : torch .fx .Node
385390 ):
386- """
387- Drop any user-specified `dim_order` keyword arguments so the pass remains
388- the single source of truth for dim-order annotations.
391+ """Drop any user-specified `dim_order` keyword arguments so the pass
392+ remains the single source of truth for dim-order annotations.
389393 """
390394 if node .op != "call_function" :
391395 return
@@ -439,9 +443,8 @@ def call(self, graph_module: torch.fx.GraphModule):
439443 return PassResult (graph_module , True )
440444
441445 def _propagate_spatial_ranks (self , nodes ):
442- """
443- Propagate `tosa_spatial_rank` metadata backwards so earlier nodes learn
444- about upcoming spatial requirements from future ops.
446+ """Propagate `tosa_spatial_rank` metadata backwards so earlier nodes
447+ learn about upcoming spatial requirements from future ops.
445448 """
446449 changed = True
447450 while changed :
0 commit comments