Skip to content

fix(multistep): change stop_target_gradients default to True, consistent with vtrace#161

Open
Sumu004 wants to merge 1 commit into
google-deepmind:mainfrom
Sumu004:fix/stop-target-gradients-default-multistep
Open

fix(multistep): change stop_target_gradients default to True, consistent with vtrace#161
Sumu004 wants to merge 1 commit into
google-deepmind:mainfrom
Sumu004:fix/stop-target-gradients-default-multistep

Conversation

@Sumu004

@Sumu004 Sumu004 commented Jun 4, 2026

Copy link
Copy Markdown

What does this PR do?

All seven return/advantage functions in multistep.py previously defaulted stop_target_gradients=False. This silently allowed gradients to flow through bootstrap targets during RL training — incorrect for standard agents and inconsistent with vtrace.py, where every function defaults to True.

This PR changes the default to True for all seven functions, resolving the four-year-old inconsistency raised in #28.

Fixes #28

Functions changed

lambda_returns                              False → True
n_step_bootstrapped_returns                 False → True
discounted_returns                          False → True
importance_corrected_td_errors              False → True
truncated_generalized_advantage_estimation  False → True   ← the main one from #28
general_off_policy_returns_from_action_values  False → True
general_off_policy_returns_from_q_and_v        False → True

Why True is the right default

In standard RL training (PPO, A3C, SAC, etc.), the bootstrap value v(s_{t+1}) is a target — a fixed estimate used to compute TD errors. Backpropagating through it creates a moving-target instability and violates the standard actor-critic update derivation. stop_target_gradients=False is only needed for meta-gradient methods (e.g., differentiating through the RL update itself), which are a niche use case that should require explicit opt-in.

vtrace.py has always used True as the default for exactly this reason.

Backward compatibility

jax.lax.stop_gradient does not affect forward-pass values — it only affects gradient computation. This change therefore cannot alter numerical results in existing code. Users who were relying on the old default for meta-learning must now pass stop_target_gradients=False explicitly.

Tests

Added StopTargetGradientsDefaultTest covering:

  • Default (True) correctly blocks gradients for lambda_returns, truncated_generalized_advantage_estimation, and n_step_bootstrapped_returns
  • Explicit stop_target_gradients=False still passes gradients (the opt-in meta-gradient path works)
  • Forward values are identical regardless of the flag

Before submitting

AI writing disclosure

  • No AI usage
  • AI-assisted: the PR was written and reviewed by a human; AI tools were used to help navigate the codebase and draft test cases.
  • AI-generated

…g vtrace

All seven return/advantage functions in multistep.py previously defaulted
stop_target_gradients=False. This silently allowed gradients to flow through
bootstrap targets during RL training — incorrect for standard agents and
inconsistent with vtrace.py, where every function has defaulted to True
since the library's inception.

Functions changed:
  lambda_returns, n_step_bootstrapped_returns, discounted_returns,
  importance_corrected_td_errors, truncated_generalized_advantage_estimation,
  general_off_policy_returns_from_action_values,
  general_off_policy_returns_from_q_and_v

The False case is still reachable by passing stop_target_gradients=False
explicitly; it is only needed for meta-gradient methods (a rare use case
that should be opt-in, not the default).

Note: stop_gradient does not affect forward-pass values, so this change
carries no risk of altering numerical results in existing code — it only
affects gradient computation. Users who were relying on the old default
(gradients through targets) for meta-learning must now pass
stop_target_gradients=False explicitly.

Also adds StopTargetGradientsDefaultTest covering:
- Default (True) blocks gradients for lambda_returns, GAE, n_step_returns
- Explicit False still passes gradients (opt-in meta-gradient path)
- Forward values are identical regardless of the flag

Fixes: google-deepmind#28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

stop_target_gradients default should be True in GAE function

1 participant