
feat(async_inference): add configurable similarity function for observation filtering #3484

Open
AriannaPietrasanta wants to merge 1 commit into huggingface:main from gbionics:similarity_fn_selection

@AriannaPietrasanta AriannaPietrasanta commented Apr 29, 2026

Title

feat(async_inference): add configurable similarity function for observation filtering

Summary / Motivation

On high-DOF real-robot deployments (30+ joints), the async inference server's hardcoded Euclidean distance filter was too aggressive: because joint-space distances across many dimensions sum to a large norm even for small per-joint deltas, incoming observations were almost always flagged as "similar" to the previous one and skipped. In practice this meant inference was only re-triggered once the action queue was fully depleted, causing jerky, reactive behaviour instead of smooth, continuously-updated control.

This PR introduces a SIMILARITY_FUNCTIONS registry in configs.py and a --similarity_fn_name CLI parameter (passed from the robot client to the policy server via SendPolicyInstructions). Built-in options are euclidean (default, same behaviour as before) and disabled (always run inference). Users can add custom functions directly to the registry. As a side-effect, the Queue(maxsize=1) observation buffer in policy_server.py is replaced with a lock + threading.Event approach that closes a race condition where a new observation could be accepted while the previous inference result was still in transit to the client.
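A minimal sketch of how such a name-to-function registry might look, assuming the `(state_a, state_b, atol) -> bool` call shape that appears in the review snippets. Plain float sequences stand in for the tensors the real code compares, and `max_abs` is a hypothetical custom entry, not something this PR ships.

```python
import math
from collections.abc import Sequence


def euclidean(a: Sequence[float], b: Sequence[float], atol: float) -> bool:
    """Similar when the L2 norm of the difference is below atol."""
    return math.dist(a, b) < atol


def disabled(a: Sequence[float], b: Sequence[float], atol: float) -> bool:
    """Never report similarity, so every observation triggers inference."""
    return False


def max_abs(a: Sequence[float], b: Sequence[float], atol: float) -> bool:
    """Hypothetical custom metric: similar when no single joint moved by atol or more."""
    return max(abs(x - y) for x, y in zip(a, b)) < atol


SIMILARITY_FUNCTIONS = {
    "euclidean": euclidean,
    "disabled": disabled,
}

# Users extend the registry by adding an entry under a new name.
SIMILARITY_FUNCTIONS["max_abs"] = max_abs


def get_similarity_function(name: str):
    """Resolve a registry name, failing with the list of valid options."""
    try:
        return SIMILARITY_FUNCTIONS[name]
    except KeyError:
        options = sorted(SIMILARITY_FUNCTIONS)
        raise ValueError(f"Unknown similarity function {name!r}. Options: {options}") from None
```

A per-joint metric such as `max_abs` does not grow with the number of joints, which is one reason a high-DOF deployment might prefer it over a norm taken across the whole state vector.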

Related issues

None.

What changed

  • configs.py: new SIMILARITY_FUNCTIONS dict and get_similarity_function() factory; similarity_fn_name field added to PolicyServerConfig and RobotClientConfig; field propagated in to_server_config_dict().
  • helpers.py: _compare_observation_states() and observations_similar() now accept a similarity_fn callable instead of hardcoding L2 norm; RemotePolicyConfig gets similarity_fn_name field.
  • policy_server.py: Queue(maxsize=1) replaced with _state_lock + _pending_obs + _inference_in_progress flag + _obs_available Event for race-free obs handoff; similarity function is initialised from the SendPolicyInstructions payload; GetActions atomically grabs the pending obs and sets the in-progress flag; the error path resets the flag.
  • docs/source/async.mdx: --similarity_fn_name documented in both CLI examples and the parameter description list.
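The handoff described for policy_server.py can be sketched roughly as follows. The attribute names come from the description above; the class name, the method names (`offer`, `on_get_actions`, `take`, `reset_on_error`) and the exact point at which the flag is cleared are assumptions made for illustration, not the PR's code.

```python
import threading


class ObservationHandoff:
    """Single-slot observation buffer guarded by a lock, an Event and a flag."""

    def __init__(self):
        self._state_lock = threading.Lock()
        self._pending_obs = None
        self._inference_in_progress = False
        self._obs_available = threading.Event()

    def offer(self, obs) -> bool:
        """Producer side: buffer obs unless actions are still being computed or delivered."""
        with self._state_lock:
            if self._inference_in_progress:
                return False  # would be stale by the time it is consumed
            self._pending_obs = obs  # newest observation replaces any older one
            self._obs_available.set()
            return True

    def on_get_actions(self) -> None:
        """Assumed reset point: a new GetActions call means the last result arrived."""
        with self._state_lock:
            self._inference_in_progress = False

    def take(self, timeout: float):
        """Consumer side: atomically grab the pending obs and mark inference as started."""
        if not self._obs_available.wait(timeout):
            return None
        with self._state_lock:
            obs, self._pending_obs = self._pending_obs, None
            self._obs_available.clear()
            if obs is not None:
                self._inference_in_progress = True
            return obs

    def reset_on_error(self) -> None:
        """Error path: clear the flag so the next observation is accepted again."""
        with self._state_lock:
            self._inference_in_progress = False
```

Because the grab and the flag update happen under the same lock, an observation arriving during inference is either buffered before the grab or rejected after it; it cannot slip in between.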

How was this tested (or how to run locally)

  • Tests added/updated: tests/async_inference/test_helpers.py, tests/async_inference/test_policy_server.py
  • pytest -q tests/async_inference/
  • Tests on the real robot.

Checklist (required before merge)

  • Linting/formatting run (pre-commit run -a)
  • All tests pass locally (pytest)
  • Documentation updated
  • CI is green
  • Community Review: I have reviewed another contributor's open PR and linked it here: #

Reviewer notes

  • The race-condition fix in policy_server.py is the most impactful change; please focus there. The _inference_in_progress flag is intentionally kept True between inference completion and the client's next GetActions call, so that stale observations are not buffered while actions are in flight.
  • The similarity function selection is entirely backward-compatible: euclidean is the default and reproduces the previous behaviour exactly.

@AriannaPietrasanta AriannaPietrasanta marked this pull request as ready for review April 29, 2026 15:59
Copilot AI review requested due to automatic review settings April 29, 2026 15:59
@github-actions github-actions Bot added the documentation and tests labels Apr 29, 2026

Copilot AI left a comment


Pull request overview

Adds a configurable observation-similarity function to the async inference stack so observation filtering can be tuned (or disabled) for high-DOF robots, and refactors the policy server’s observation handoff to avoid races while actions are in flight.

Changes:

  • Introduces a SIMILARITY_FUNCTIONS registry + get_similarity_function() and propagates similarity_fn_name through client/server configs.
  • Updates observation similarity checking to use an injected similarity callable instead of a hardcoded L2 norm.
  • Replaces the server’s Queue(maxsize=1) with a lock + Event + in-progress flag for atomic observation consumption and buffering rules; updates tests/docs accordingly.

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| src/lerobot/async_inference/configs.py | Adds similarity function registry + config plumbing for similarity_fn_name. |
| src/lerobot/async_inference/helpers.py | Refactors similarity checks to accept a similarity_fn callable; extends RemotePolicyConfig. |
| src/lerobot/async_inference/policy_server.py | Implements race-free pending-observation handoff and initializes similarity function from client instructions. |
| src/lerobot/async_inference/robot_client.py | Sends similarity_fn_name to the server via RemotePolicyConfig. |
| src/lerobot/async_inference/constants.py | Pure formatting/readability change for supported lists. |
| tests/async_inference/test_helpers.py | Adds coverage for similarity function selection and disabled behavior. |
| tests/async_inference/test_policy_server.py | Updates tests for the new pending-observation mechanism and in-progress filtering behavior. |
| docs/source/async.mdx | Documents the new --similarity_fn_name CLI parameter and options. |


Comment on lines 286 to 292
  def observations_similar(
-     obs1: TimedObservation, obs2: TimedObservation, lerobot_features: dict[str, dict], atol: float = 1
+     obs1: TimedObservation,
+     obs2: TimedObservation,
+     lerobot_features: dict[str, dict],
+     similarity_fn,
+     atol: float = 1,
  ) -> bool:

Copilot AI Apr 29, 2026


observations_similar() now requires a similarity_fn argument with no default, which is a breaking API change compared to the previous signature. If this helper is used outside policy_server.py, consider making similarity_fn optional with a default that preserves prior behavior (euclidean distance) so existing callers don’t break.

Copilot uses AI. Check for mistakes.
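One way to keep existing callers working is an optional parameter that falls back to the previous metric. The sketch below is a simplified stand-in: `states_similar` is a hypothetical name, and it compares plain float sequences rather than the `TimedObservation` objects and feature dict the real helper receives.

```python
import math
from collections.abc import Sequence


def _euclidean(a: Sequence[float], b: Sequence[float], atol: float) -> bool:
    """Previous behaviour: L2 norm of the difference below atol."""
    return math.dist(a, b) < atol


def states_similar(state1, state2, similarity_fn=None, atol: float = 1.0) -> bool:
    """Compare two states; callers that pass no similarity_fn get the Euclidean check."""
    if similarity_fn is None:
        similarity_fn = _euclidean
    return bool(similarity_fn(state1, state2, atol))
```

Placing `similarity_fn` after the existing positional parameters, with `None` as its default, keeps both positional and keyword call sites from before the change valid.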
Comment on lines +180 to +186
# Similarity function configuration (CLI-compatible, passed to server)
similarity_fn_name: str = field(
    default="euclidean",
    metadata={
        "help": f"Name of similarity function to use. Options: {list(SIMILARITY_FUNCTIONS.keys())}"
    },
)

Copilot AI Apr 29, 2026


RobotClientConfig.__post_init__() validates aggregate_fn_name by resolving it, but similarity_fn_name is not validated. If a user passes an unknown similarity function via CLI, the failure will only surface on the server (and possibly as a less clear gRPC error). Consider validating similarity_fn_name in the client config as well (e.g., by checking SIMILARITY_FUNCTIONS / calling get_similarity_function).
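Client-side validation of this kind could look like the sketch below. `ClientConfigSketch` is a hypothetical stand-in for the real config class, and the registry values are placeholders; only the membership check in `__post_init__` is the point.

```python
from dataclasses import dataclass, field

# Placeholder registry: the real one maps names to similarity callables.
SIMILARITY_FUNCTIONS = {"euclidean": object(), "disabled": object()}


@dataclass
class ClientConfigSketch:
    similarity_fn_name: str = field(
        default="euclidean",
        metadata={"help": f"Name of similarity function to use. Options: {list(SIMILARITY_FUNCTIONS)}"},
    )

    def __post_init__(self):
        # Fail on the client, before anything is sent to the server.
        if self.similarity_fn_name not in SIMILARITY_FUNCTIONS:
            raise ValueError(
                f"Unknown similarity_fn_name {self.similarity_fn_name!r}. "
                f"Options: {list(SIMILARITY_FUNCTIONS)}"
            )
```

With this check a mistyped name fails at config construction with a message listing the valid options, instead of surfacing later as a server-side error.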

Comment on lines +155 to +157
# Update server config with similarity function from client
self.config.similarity_fn_name = policy_specs.similarity_fn_name
self.config.similarity_fn = get_similarity_function(policy_specs.similarity_fn_name)

Copilot AI Apr 29, 2026


RemotePolicyConfig is sent over the wire via pickle. If a newer server receives a RemotePolicyConfig pickled by an older client (without similarity_fn_name), policy_specs.similarity_fn_name will raise AttributeError here and the client will fail to connect. Consider using getattr(policy_specs, "similarity_fn_name", "euclidean") (and/or defaulting when missing) so mixed-version deployments still fall back to the previous behavior.

Suggested change
- # Update server config with similarity function from client
- self.config.similarity_fn_name = policy_specs.similarity_fn_name
- self.config.similarity_fn = get_similarity_function(policy_specs.similarity_fn_name)
+ # Update server config with similarity function from client, defaulting
+ # for backward compatibility with older pickled RemotePolicyConfig objects.
+ similarity_fn_name = getattr(policy_specs, "similarity_fn_name", "euclidean")
+ self.config.similarity_fn_name = similarity_fn_name
+ self.config.similarity_fn = get_similarity_function(similarity_fn_name)

Comment on lines +276 to +283
  def _compare_observation_states(
      obs1_state: torch.Tensor,
      obs2_state: torch.Tensor,
      atol: float,
      similarity_fn: callable,
  ) -> bool:
      """Check if two observation states are similar, under a tolerance threshold"""
-     return bool(torch.linalg.norm(obs1_state - obs2_state) < atol)
+     return bool(similarity_fn(obs1_state, obs2_state, atol))

Copilot AI Apr 29, 2026


Type annotation uses callable, which is the built-in function, not a typing type. Use collections.abc.Callable / typing.Callable with an explicit signature (e.g., (Tensor, Tensor, float) -> bool) to avoid confusing annotations and improve static checking.
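A sketch of the annotation with an explicit signature follows. To stay dependency-free, `State` is a hypothetical alias for a float sequence; in the PR's code the two state arguments would be `torch.Tensor`. The alias name `SimilarityFn` and the un-prefixed function name are also illustrative.

```python
import math
from collections.abc import Callable, Sequence

State = Sequence[float]  # stand-in for torch.Tensor in the real helper
SimilarityFn = Callable[[State, State, float], bool]


def compare_observation_states(
    obs1_state: State,
    obs2_state: State,
    atol: float,
    similarity_fn: SimilarityFn,
) -> bool:
    """Check if two observation states are similar, under a tolerance threshold."""
    return bool(similarity_fn(obs1_state, obs2_state, atol))


def euclidean(a: State, b: State, atol: float) -> bool:
    """Matches SimilarityFn: two states and a tolerance in, a bool out."""
    return math.dist(a, b) < atol
```

With the alias in place, a static checker can flag a function that takes the wrong number of arguments or returns a non-bool, which the bare `callable` annotation cannot express.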
