feat(async_inference): add configurable similarity function for observation filtering #3484
AriannaPietrasanta wants to merge 1 commit into huggingface:main
Pull request overview
Adds a configurable observation-similarity function to the async inference stack so observation filtering can be tuned (or disabled) for high-DOF robots, and refactors the policy server’s observation handoff to avoid races while actions are in flight.
Changes:
- Introduces a `SIMILARITY_FUNCTIONS` registry + `get_similarity_function()` and propagates `similarity_fn_name` through client/server configs.
- Updates observation similarity checking to use an injected similarity callable instead of a hardcoded L2 norm.
- Replaces the server's `Queue(maxsize=1)` with a lock + `Event` + in-progress flag for atomic observation consumption and buffering rules; updates tests/docs accordingly.
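A minimal sketch of what such a registry might look like. The function bodies and the error message are assumptions for illustration, not the PR's actual code, and plain Python sequences stand in for the `torch.Tensor` states used in lerobot. The `disabled` entry follows the PR description (that option always runs inference), so it reports every pair as dissimilar.

```python
import math
from collections.abc import Callable, Sequence

# Assumed signature: (state_a, state_b, atol) -> True when the states count as similar.
SimilarityFn = Callable[[Sequence[float], Sequence[float], float], bool]


def euclidean(a: Sequence[float], b: Sequence[float], atol: float) -> bool:
    """Similar when the L2 norm of the difference is below atol."""
    return math.dist(a, b) < atol


def disabled(a: Sequence[float], b: Sequence[float], atol: float) -> bool:
    """Never similar, so no observation is filtered out."""
    return False


SIMILARITY_FUNCTIONS: dict[str, SimilarityFn] = {
    "euclidean": euclidean,
    "disabled": disabled,
}


def get_similarity_function(name: str) -> SimilarityFn:
    """Resolve a registered similarity function by name, failing loudly on typos."""
    try:
        return SIMILARITY_FUNCTIONS[name]
    except KeyError:
        raise ValueError(
            f"Unknown similarity function {name!r}. Options: {list(SIMILARITY_FUNCTIONS)}"
        ) from None
```

With a layout like this, a custom metric is one extra dictionary entry, and both client and server can resolve the same name to the same callable.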
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `src/lerobot/async_inference/configs.py` | Adds similarity function registry + config plumbing for `similarity_fn_name`. |
| `src/lerobot/async_inference/helpers.py` | Refactors similarity checks to accept a `similarity_fn` callable; extends `RemotePolicyConfig`. |
| `src/lerobot/async_inference/policy_server.py` | Implements race-free pending-observation handoff and initializes similarity function from client instructions. |
| `src/lerobot/async_inference/robot_client.py` | Sends `similarity_fn_name` to the server via `RemotePolicyConfig`. |
| `src/lerobot/async_inference/constants.py` | Pure formatting/readability change for supported lists. |
| `tests/async_inference/test_helpers.py` | Adds coverage for similarity function selection and disabled behavior. |
| `tests/async_inference/test_policy_server.py` | Updates tests for the new pending-observation mechanism and in-progress filtering behavior. |
| `docs/source/async.mdx` | Documents the new `--similarity_fn_name` CLI parameter and options. |

     def observations_similar(
    -    obs1: TimedObservation, obs2: TimedObservation, lerobot_features: dict[str, dict], atol: float = 1
    +    obs1: TimedObservation,
    +    obs2: TimedObservation,
    +    lerobot_features: dict[str, dict],
    +    similarity_fn,
    +    atol: float = 1,
     ) -> bool:

observations_similar() now requires a similarity_fn argument with no default, which is a breaking API change compared to the previous signature. If this helper is used outside policy_server.py, consider making similarity_fn optional with a default that preserves prior behavior (euclidean distance) so existing callers don’t break.
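One way to keep existing call sites working, sketched under stated assumptions: `states_similar` and `_euclidean` are hypothetical names standing in for `observations_similar()` and the previous L2 check, and plain sequences stand in for tensors.

```python
import math
from collections.abc import Callable, Sequence
from typing import Optional

SimilarityFn = Callable[[Sequence[float], Sequence[float], float], bool]


def _euclidean(a: Sequence[float], b: Sequence[float], atol: float) -> bool:
    # Previous hardcoded behaviour: L2 norm of the difference below atol.
    return math.dist(a, b) < atol


def states_similar(
    state1: Sequence[float],
    state2: Sequence[float],
    atol: float = 1,
    similarity_fn: Optional[SimilarityFn] = None,
) -> bool:
    """Hypothetical stand-in for observations_similar() with an optional similarity_fn."""
    if similarity_fn is None:
        similarity_fn = _euclidean  # fall back to the pre-PR behaviour
    return bool(similarity_fn(state1, state2, atol))
```

Placing the optional parameter after `atol` means callers written against the old signature would not need to change, whether they pass `atol` positionally or by keyword.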

    # Similarity function configuration (CLI-compatible, passed to server)
    similarity_fn_name: str = field(
        default="euclidean",
        metadata={
            "help": f"Name of similarity function to use. Options: {list(SIMILARITY_FUNCTIONS.keys())}"
        },
    )

RobotClientConfig.__post_init__() validates aggregate_fn_name by resolving it, but similarity_fn_name is not validated. If a user passes an unknown similarity function via CLI, the failure will only surface on the server (and possibly as a less clear gRPC error). Consider validating similarity_fn_name in the client config as well (e.g., by checking SIMILARITY_FUNCTIONS / calling get_similarity_function).
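A sketch of client-side validation along those lines. `ClientConfigSketch` is a hypothetical, trimmed-down stand-in for `RobotClientConfig`, and the registry here is a local placeholder rather than the one in `configs.py`.

```python
from dataclasses import dataclass, field

# Placeholder registry; per the PR description the real one lives in configs.py.
SIMILARITY_FUNCTIONS = {
    "euclidean": lambda a, b, atol: sum((x - y) ** 2 for x, y in zip(a, b)) ** 0.5 < atol,
    "disabled": lambda a, b, atol: False,
}


@dataclass
class ClientConfigSketch:
    """Hypothetical stand-in for RobotClientConfig with only the relevant field."""

    similarity_fn_name: str = field(default="euclidean")

    def __post_init__(self) -> None:
        # Fail on the client, before anything is sent to the server.
        if self.similarity_fn_name not in SIMILARITY_FUNCTIONS:
            raise ValueError(
                f"Unknown similarity_fn_name {self.similarity_fn_name!r}. "
                f"Options: {list(SIMILARITY_FUNCTIONS)}"
            )
```

A typo on the command line would then surface as a `ValueError` at config construction instead of as a server-side failure.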

    # Update server config with similarity function from client
    self.config.similarity_fn_name = policy_specs.similarity_fn_name
    self.config.similarity_fn = get_similarity_function(policy_specs.similarity_fn_name)

RemotePolicyConfig is sent over the wire via pickle. If a newer server receives a RemotePolicyConfig pickled by an older client (without similarity_fn_name), policy_specs.similarity_fn_name will raise AttributeError here and the client will fail to connect. Consider using getattr(policy_specs, "similarity_fn_name", "euclidean") (and/or defaulting when missing) so mixed-version deployments still fall back to the previous behavior.

    -# Update server config with similarity function from client
    -self.config.similarity_fn_name = policy_specs.similarity_fn_name
    -self.config.similarity_fn = get_similarity_function(policy_specs.similarity_fn_name)
    +# Update server config with similarity function from client, defaulting
    +# for backward compatibility with older pickled RemotePolicyConfig objects.
    +similarity_fn_name = getattr(policy_specs, "similarity_fn_name", "euclidean")
    +self.config.similarity_fn_name = similarity_fn_name
    +self.config.similarity_fn = get_similarity_function(similarity_fn_name)


     def _compare_observation_states(
         obs1_state: torch.Tensor,
         obs2_state: torch.Tensor,
         atol: float,
    +    similarity_fn: callable,
     ) -> bool:
         """Check if two observation states are similar, under a tolerance threshold"""
    -    return bool(torch.linalg.norm(obs1_state - obs2_state) < atol)
    +    return bool(similarity_fn(obs1_state, obs2_state, atol))

Type annotation uses callable, which is the built-in function, not a typing type. Use collections.abc.Callable / typing.Callable with an explicit signature (e.g., (Tensor, Tensor, float) -> bool) to avoid confusing annotations and improve static checking.
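A sketch of the suggested annotation. `State`, `SimilarityFn`, and `compare_states` are illustrative names, and `State` is a plain sequence here so the snippet runs without torch; in the PR's code the state type would be `torch.Tensor`.

```python
from collections.abc import Callable, Sequence

# Plain sequences stand in for torch.Tensor so the sketch has no dependencies.
State = Sequence[float]
# Explicit signature: (state, state, atol) -> bool.
SimilarityFn = Callable[[State, State, float], bool]


def compare_states(
    obs1_state: State,
    obs2_state: State,
    atol: float,
    similarity_fn: SimilarityFn,
) -> bool:
    """Check if two observation states are similar, under a tolerance threshold."""
    return bool(similarity_fn(obs1_state, obs2_state, atol))
```

With the alias in place, a type checker can flag a similarity function that takes the wrong number of arguments, which the bare `callable` annotation cannot express.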
Title
feat(async_inference): add configurable similarity function for observation filtering
Summary / Motivation
On high-DOF real-robot deployments (30+ joints), the async inference server's hardcoded Euclidean distance filter was too aggressive: because joint-space distances across many dimensions sum to a large norm even for small per-joint deltas, incoming observations were almost always flagged as "similar" to the previous one and skipped. In practice this meant inference was only re-triggered once the action queue was fully depleted, causing jerky, reactive behaviour instead of smooth, continuously-updated control.
This PR introduces a `SIMILARITY_FUNCTIONS` registry in `configs.py` and a `--similarity_fn_name` CLI parameter (passed from the robot client to the policy server via `SendPolicyInstructions`). Built-in options are `euclidean` (default, same behaviour as before) and `disabled` (always run inference). Users can add custom functions directly to the registry. As a side-effect, the `Queue(maxsize=1)` observation buffer in `policy_server.py` is replaced with a lock + `threading.Event` approach that closes a race condition where a new observation could be accepted while the previous inference result was still in transit to the client.
Related issues
None.
What changed
- `configs.py`: new `SIMILARITY_FUNCTIONS` dict and `get_similarity_function()` factory; `similarity_fn_name` field added to `PolicyServerConfig` and `RobotClientConfig`; field propagated into `_server_config_dict()`.
- `helpers.py`: `_compare_observation_states()` and `observations_similar()` now accept a `similarity_fn` callable instead of hardcoding L2 norm; `RemotePolicyConfig` gets `similarity_fn_name` field.
- `policy_server.py`: `Queue(maxsize=1)` replaced with `_state_lock` + `_pending_obs` + `_inference_in_progress` flag + `_obs_available` Event for race-free obs handoff; similarity function is initialised from the `SendPolicyInstructions` payload; `GetActions` atomically grabs the pending obs and sets the in-progress flag; the error path resets the flag.
- `docs/source/async.mdx`: `--similarity_fn_name` documented in both CLI examples and the parameter description list.
How was this tested (or how to run locally)
- `tests/async_inference/test_helpers.py`, `tests/async_inference/test_policy_server.py`
- `pytest -q tests/async_inference/`
Checklist (required before merge)
- `pre-commit run -a`
- `pytest`
Reviewer notes
- `policy_server.py` is the most impactful change; please focus there. The `_inference_in_progress` flag is intentionally kept `True` between inference completion and the client's next `GetActions` call, so that stale observations are not buffered while actions are in flight.
- `euclidean` is the default and reproduces the previous behaviour exactly.
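The handoff described above can be sketched roughly as follows. This is an illustration under assumptions, not the code in `policy_server.py`: the class `ObservationHandoff` and its method names are hypothetical, and resetting the flag at the start of `take()` is only one plausible reading of "kept `True` until the client's next `GetActions` call".

```python
import threading


class ObservationHandoff:
    """Hypothetical single-slot handoff guarded by a lock, a flag, and an Event."""

    def __init__(self) -> None:
        self._state_lock = threading.Lock()
        self._pending_obs = None
        self._inference_in_progress = False
        self._obs_available = threading.Event()

    def offer(self, obs) -> bool:
        """Producer side: buffer obs unless inference is in progress."""
        with self._state_lock:
            if self._inference_in_progress:
                return False  # actions still in flight: drop the observation
            self._pending_obs = obs  # newest observation replaces any older one
            self._obs_available.set()
            return True

    def take(self, timeout=None):
        """Consumer side (think GetActions): grab the pending obs atomically."""
        with self._state_lock:
            # Assumption: a new call means the previous actions were delivered.
            self._inference_in_progress = False
        if not self._obs_available.wait(timeout):
            return None  # nothing arrived in time
        with self._state_lock:
            obs = self._pending_obs
            self._pending_obs = None
            self._obs_available.clear()
            if obs is not None:
                self._inference_in_progress = True
            return obs

    def fail(self) -> None:
        """Error path: clear the flag so new observations are accepted again."""
        with self._state_lock:
            self._inference_in_progress = False
```

In this sketch the slot and the flag only change while the lock is held, so an observation cannot be accepted between the consumer taking the slot and marking inference as in progress.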