Skip to content

Make ScalelessRMSNorm a torch.nn.RMSNorm; fix SDPACustom view -> reshape (#19376)#19376

Open
navsud wants to merge 1 commit intopytorch:mainfrom
navsud:export-D104258950
Open

Make ScalelessRMSNorm a torch.nn.RMSNorm; fix SDPACustom view -> reshape (#19376)#19376
navsud wants to merge 1 commit intopytorch:mainfrom
navsud:export-D104258950

Conversation

@navsud
Copy link
Copy Markdown
Contributor

@navsud navsud commented May 7, 2026

Summary:

Two related changes that together unblock the QNN export path for VLM/STITO:

(1) ScalelessRMSNorm: re-implement as torch.nn.RMSNorm subclass

ScalelessRMSNorm was previously implemented as a hand-rolled RMS normalization
(decomposed into mean / rsqrt / mul). On the QNN export path, this decomposition
fails to lower for an LLM. Using torch.nn.RMSNorm() directly works.

Re-implement ScalelessRMSNorm as a torch.nn.RMSNorm subclass whose weight is
hardcoded to ones and frozen (requires_grad=False). This keeps the public
interface (ScalelessRMSNorm(dim, eps)) unchanged while letting backends see a
proper RMSNorm op so it composes/decomposes cleanly for QNN.

(2) SDPACustom / QuantizedSDPA: replace .view() with .reshape()

Switching to torch.nn.RMSNorm changes how strides propagate through the export
graph compared to the hand-rolled decomposition, exposing a latent bug in
source_transformation/sdpa.py. The output of torch.ops.llama.custom_sdpa retains
the non-contiguous (transposed) strides of its inputs, so
output.view(bsz, seqlen, self.dim) — which merges the last two dims
(n_heads, head_dim) — fails during torch.export with:

Cannot view a tensor with shape (1, s0, 32, 64) and strides
(2048*s0, 64, 64*s0, 1) as a tensor with shape (1, s0, 2048)

Switching to .reshape() inserts .contiguous() only when needed and matches the pattern already used elsewhere in this file (SDPASimple, SDPAFlex, SDPACoreML, and attention.py).

Reviewed By: billmguo, telgamal-1

Differential Revision: D104258950

@navsud navsud requested a review from lucylq as a code owner May 7, 2026 18:37
@pytorch-bot
Copy link
Copy Markdown

pytorch-bot Bot commented May 7, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19376

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 3 Pending

As of commit 07aa7e5 with merge base 91aef57 (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 7, 2026
@meta-codesync
Copy link
Copy Markdown
Contributor

meta-codesync Bot commented May 7, 2026

@navsud has exported this pull request. If you are a Meta employee, you can view the originating Diff in D104258950.

@navsud navsud added the release notes: none Do not include this in the release notes label May 7, 2026
@navsud navsud requested a review from billmguo May 7, 2026 20:08
@meta-codesync meta-codesync Bot changed the title Make ScalelessRMSNorm a torch.nn.RMSNorm with frozen ones weight Make ScalelessRMSNorm a torch.nn.RMSNorm; fix SDPACustom view -> reshape May 8, 2026
@navsud navsud force-pushed the export-D104258950 branch from 0900104 to 205bba9 Compare May 8, 2026 00:26
…ape (pytorch#19376)

Summary:
Pull Request resolved: pytorch#19376

Two related changes that together unblock the QNN export path for VLM/STITO:

(1) ScalelessRMSNorm: re-implement as torch.nn.RMSNorm subclass

ScalelessRMSNorm was previously implemented as a hand-rolled RMS normalization
(decomposed into mean / rsqrt / mul). On the QNN export path, this decomposition
fails to lower for an LLM. Using torch.nn.RMSNorm() directly works.

Re-implement ScalelessRMSNorm as a torch.nn.RMSNorm subclass whose weight is
hardcoded to ones and frozen (requires_grad=False). This keeps the public
interface (ScalelessRMSNorm(dim, eps)) unchanged while letting backends see a
proper RMSNorm op so it composes/decomposes cleanly for QNN.

(2) SDPACustom / QuantizedSDPA: replace .view() with .reshape()

Switching to torch.nn.RMSNorm changes how strides propagate through the export
graph compared to the hand-rolled decomposition, exposing a latent bug in
source_transformation/sdpa.py. The output of torch.ops.llama.custom_sdpa retains
the non-contiguous (transposed) strides of its inputs, so
output.view(bsz, seqlen, self.dim) — which merges the last two dims
(n_heads, head_dim) — fails during torch.export with:

    Cannot view a tensor with shape (1, s0, 32, 64) and strides
    (2048*s0, 64, 64*s0, 1) as a tensor with shape (1, s0, 2048)

Switching to .reshape() inserts .contiguous() only when needed and matches the pattern already used elsewhere in this file (SDPASimple, SDPAFlex, SDPACoreML, and attention.py).

Reviewed By: billmguo, telgamal-1

Differential Revision: D104258950
@meta-codesync meta-codesync Bot changed the title Make ScalelessRMSNorm a torch.nn.RMSNorm; fix SDPACustom view -> reshape Make ScalelessRMSNorm a torch.nn.RMSNorm; fix SDPACustom view -> reshape (#19376) May 8, 2026
@navsud navsud force-pushed the export-D104258950 branch from 205bba9 to 07aa7e5 Compare May 8, 2026 00:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. fb-exported meta-exported release notes: none Do not include this in the release notes

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants