From 010cd22f6bae119135d8dd78a29a8dba749b3bfd Mon Sep 17 00:00:00 2001
From: Jash Shah
Date: Tue, 24 Feb 2026 15:42:45 -0800
Subject: [PATCH] Fix three bugs: missing raise, wrong RND index, variable typo

1. rnn.py: Add missing `raise` before NotImplementedError -- the exception
   was being constructed but never raised, causing silent failure when
   resetting hidden state of done environments with a custom hidden state.

2. ppo.py: Fix wrong index in broadcast_parameters -- when RND is enabled,
   the predictor was loading model_params[1] (critic state) instead of
   model_params[2] (its own state), corrupting RND weights during
   multi-GPU training.

3. cnn_model.py: Fix variable name typo `latend_cnn` -> `latent_cnn`.
---
 rsl_rl/algorithms/ppo.py   | 2 +-
 rsl_rl/models/cnn_model.py | 4 ++--
 rsl_rl/modules/rnn.py      | 2 +-
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/rsl_rl/algorithms/ppo.py b/rsl_rl/algorithms/ppo.py
index 0e788972..71e7b3de 100644
--- a/rsl_rl/algorithms/ppo.py
+++ b/rsl_rl/algorithms/ppo.py
@@ -513,7 +513,7 @@ def broadcast_parameters(self) -> None:
         self.actor.load_state_dict(model_params[0])
         self.critic.load_state_dict(model_params[1])
         if self.rnd:
-            self.rnd.predictor.load_state_dict(model_params[1])
+            self.rnd.predictor.load_state_dict(model_params[2])
 
     def reduce_parameters(self) -> None:
         """Collect gradients from all GPUs and average them.
diff --git a/rsl_rl/models/cnn_model.py b/rsl_rl/models/cnn_model.py
index f79cf468..034cf46b 100644
--- a/rsl_rl/models/cnn_model.py
+++ b/rsl_rl/models/cnn_model.py
@@ -119,9 +119,9 @@ def get_latent(
         latent_1d = super().get_latent(obs)
         # Process 2D observation groups with CNNs
         latent_cnn_list = [self.cnns[obs_group](obs[obs_group]) for obs_group in self.obs_groups_2d]
-        latend_cnn = torch.cat(latent_cnn_list, dim=-1)
+        latent_cnn = torch.cat(latent_cnn_list, dim=-1)
         # Concatenate 1D and CNN latents
-        return torch.cat([latent_1d, latend_cnn], dim=-1)
+        return torch.cat([latent_1d, latent_cnn], dim=-1)
 
     def as_jit(self) -> nn.Module:
         """Return a version of the model compatible with Torch JIT export."""
diff --git a/rsl_rl/modules/rnn.py b/rsl_rl/modules/rnn.py
index 2cdf7158..fe4116a0 100644
--- a/rsl_rl/modules/rnn.py
+++ b/rsl_rl/modules/rnn.py
@@ -62,7 +62,7 @@ def reset(self, dones: torch.Tensor | None = None, hidden_state: HiddenState = N
                 else:
                     self.hidden_state[..., dones == 1, :] = 0.0
             else:
-                NotImplementedError(
+                raise NotImplementedError(
                     "Resetting the hidden state of done environments with a custom hidden state is not implemented"
                 )
 