fix(multistep): change stop_target_gradients default to True, consistent with vtrace#161
Open
Sumu004 wants to merge 1 commit into
Open
Conversation
…g vtrace All seven return/advantage functions in multistep.py previously defaulted stop_target_gradients=False. This silently allowed gradients to flow through bootstrap targets during RL training — incorrect for standard agents and inconsistent with vtrace.py, where every function has defaulted to True since the library's inception. Functions changed: lambda_returns, n_step_bootstrapped_returns, discounted_returns, importance_corrected_td_errors, truncated_generalized_advantage_estimation, general_off_policy_returns_from_action_values, general_off_policy_returns_from_q_and_v The False case is still reachable by passing stop_target_gradients=False explicitly; it is only needed for meta-gradient methods (a rare use case that should be opt-in, not the default). Note: stop_gradient does not affect forward-pass values, so this change carries no risk of altering numerical results in existing code — it only affects gradient computation. Users who were relying on the old default (gradients through targets) for meta-learning must now pass stop_target_gradients=False explicitly. Also adds StopTargetGradientsDefaultTest covering: - Default (True) blocks gradients for lambda_returns, GAE, n_step_returns - Explicit False still passes gradients (opt-in meta-gradient path) - Forward values are identical regardless of the flag Fixes: google-deepmind#28
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What does this PR do?
All seven return/advantage functions in
multistep.pypreviously defaultedstop_target_gradients=False. This silently allowed gradients to flow through bootstrap targets during RL training — incorrect for standard agents and inconsistent withvtrace.py, where every function defaults toTrue.This PR changes the default to
Truefor all seven functions, resolving the four-year-old inconsistency raised in #28.Fixes #28
Functions changed
Why True is the right default
In standard RL training (PPO, A3C, SAC, etc.), the bootstrap value
v(s_{t+1})is a target — a fixed estimate used to compute TD errors. Backpropagating through it creates a moving-target instability and violates the standard actor-critic update derivation.stop_target_gradients=Falseis only needed for meta-gradient methods (e.g., differentiating through the RL update itself), which are a niche use case that should require explicit opt-in.vtrace.pyhas always usedTrueas the default for exactly this reason.Backward compatibility
jax.lax.stop_gradientdoes not affect forward-pass values — it only affects gradient computation. This change therefore cannot alter numerical results in existing code. Users who were relying on the old default for meta-learning must now passstop_target_gradients=Falseexplicitly.Tests
Added
StopTargetGradientsDefaultTestcovering:True) correctly blocks gradients forlambda_returns,truncated_generalized_advantage_estimation, andn_step_bootstrapped_returnsstop_target_gradients=Falsestill passes gradients (the opt-in meta-gradient path works)Before submitting
AI writing disclosure