Skip to content
This repository was archived by the owner on May 27, 2026. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 14 additions & 19 deletions compiler_gym/envs/llvm/llvm_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
CostFunctionReward,
NormalizedReward,
)
from compiler_gym.errors import BenchmarkInitError
from compiler_gym.errors import BenchmarkInitError, SessionNotFound
from compiler_gym.service.client_service_compiler_env import ClientServiceCompilerEnv
from compiler_gym.spaces import Box, Commandline
from compiler_gym.spaces import Dict as DictSpace
Expand Down Expand Up @@ -363,7 +363,7 @@ def __init__(

def reset(self, *args, **kwargs):
try:
observation = super().reset(*args, **kwargs)
return super().reset(*args, **kwargs)
except ValueError as e:
# Catch and re-raise some known benchmark initialization errors with
# a more informative error type.
Expand All @@ -379,15 +379,6 @@ def reset(self, *args, **kwargs):
raise BenchmarkInitError(str(e)) from e
raise

# Resend the runtimes-per-observation session parameter, if it is a
# non-default value.
if self._runtimes_per_observation_count is not None:
self.runtime_observation_count = self._runtimes_per_observation_count
if self._runtimes_warmup_per_observation_count is not None:
self.runtime_warmup_runs_count = self._runtimes_warmup_per_observation_count

return observation

def make_benchmark(
self,
inputs: Union[
Expand Down Expand Up @@ -612,10 +603,12 @@ def runtime_observation_count(self) -> int:

@runtime_observation_count.setter
def runtime_observation_count(self, n: int) -> None:
if self.in_episode:
self.send_param("llvm.set_runtimes_per_observation_count", str(n))
# NOTE(cummins): Keep this after the send_param() call because
# send_param() will raise an error if the valid is invalid.
try:
self.send_param(
"llvm.set_runtimes_per_observation_count", str(n), resend_on_reset=True
)
except SessionNotFound:
pass # Not in session yet, will be sent on reset().
self._runtimes_per_observation_count = n

@property
Expand Down Expand Up @@ -648,12 +641,14 @@ def runtime_warmup_runs_count(self) -> int:

@runtime_warmup_runs_count.setter
def runtime_warmup_runs_count(self, n: int) -> None:
if self.in_episode:
try:
self.send_param(
"llvm.set_warmup_runs_count_per_runtime_observation", str(n)
"llvm.set_warmup_runs_count_per_runtime_observation",
str(n),
resend_on_reset=True,
)
# NOTE(cummins): Keep this after the send_param() call because
# send_param() will raise an error if the valid is invalid.
except SessionNotFound:
pass # Not in session yet, will be sent on reset().
self._runtimes_warmup_per_observation_count = n

def fork(self):
Expand Down
28 changes: 24 additions & 4 deletions compiler_gym/service/client_service_compiler_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ def __init__(

self._service_endpoint: Union[str, Path] = service
self._connection_settings = connection_settings or ConnectionOpts()
self._params_to_send_on_reset: List[SessionParameter] = []

self.service = service_connection or CompilerGymServiceConnection(
endpoint=self._service_endpoint,
Expand Down Expand Up @@ -788,6 +789,12 @@ def _call_with_error(
reply.new_action_space
)

# Re-send any session parameters that we marked as needing to be
# re-sent on reset(). Do this before any other initialization as they
# may affect the behavior of subsequent service calls.
if self._params_to_send_on_reset:
self.send_params(*[(p.key, p.value) for p in self._params_to_send_on_reset])

self.reward.reset(benchmark=self.benchmark, observation_view=self.observation)
if self.reward_space:
self.episode_reward = 0.0
Expand Down Expand Up @@ -1236,7 +1243,7 @@ def validate(self, state: Optional[CompilerEnvState] = None) -> ValidationResult
**validation,
)

def send_param(self, key: str, value: str) -> str:
def send_param(self, key: str, value: str, resend_on_reset: bool = False) -> str:
"""Send a single <key, value> parameter to the compiler service.

See :meth:`send_params() <compiler_gym.envs.ClientServiceCompilerEnv.send_params>`
Expand All @@ -1246,14 +1253,19 @@ def send_param(self, key: str, value: str) -> str:

:param value: The parameter value.

:param resend_on_reset: Whether to resend this parameter to the compiler
service on :code:`reset()`.

:return: The response from the compiler service.

:raises SessionNotFound: If called before :meth:`reset()
<compiler_gym.envs.ClientServiceCompilerEnv.reset>`.
"""
return self.send_params((key, value))[0]
return self.send_params((key, value), resend_on_reset=resend_on_reset)[0]

def send_params(self, *params: Iterable[Tuple[str, str]]) -> List[str]:
def send_params(
self, *params: Iterable[Tuple[str, str]], resend_on_reset: bool = False
) -> List[str]:
"""Send a list of <key, value> parameters to the compiler service.

This provides a mechanism to send messages to the backend compilation
Expand All @@ -1270,17 +1282,25 @@ def send_params(self, *params: Iterable[Tuple[str, str]]) -> List[str]:
:param params: A list of parameters, where each parameter is a
:code:`(key, value)` tuple.

:param resend_on_reset: Whether to resend this parameter to the compiler
service on :code:`reset()`.

:return: A list of string responses, one per parameter.

:raises SessionNotFound: If called before :meth:`reset()
<compiler_gym.envs.ClientServiceCompilerEnv.reset>`.
"""
params_to_send = [SessionParameter(key=k, value=v) for (k, v) in params]

if resend_on_reset:
self._params_to_send_on_reset += params_to_send

if not self.in_episode:
raise SessionNotFound("Must call reset() before send_params()")

request = SendSessionParameterRequest(
session_id=self._session_id,
parameter=[SessionParameter(key=k, value=v) for (k, v) in params],
parameter=params_to_send,
)
reply: SendSessionParameterReply = self.service(
self.service.stub.SendSessionParameter, request
Expand Down
2 changes: 2 additions & 0 deletions tests/llvm/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,8 @@ py_test(
deps = [
"//compiler_gym/envs/llvm",
"//compiler_gym/service:connection",
"//compiler_gym/spaces",
"//compiler_gym/util",
"//tests:test_main",
"//tests/pytest_plugins:llvm",
],
Expand Down
62 changes: 62 additions & 0 deletions tests/llvm/runtime_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
# LICENSE file in the root directory of this source tree.
"""Integrations tests for LLVM runtime support."""
from pathlib import Path
from typing import List

import numpy as np
import pytest
from flaky import flaky

from compiler_gym.envs.llvm import LlvmEnv, llvm_benchmark
from compiler_gym.spaces.reward import Reward
from compiler_gym.util.gym_type_hints import ActionType, ObservationType
from tests.test_main import main

pytest_plugins = ["tests.pytest_plugins.llvm"]
Expand Down Expand Up @@ -144,5 +147,64 @@ def test_default_runtime_observation_count_fork(env: LlvmEnv):
assert fkd.runtime_warmup_runs_count == wc


class RewardDerivedFromRuntime(Reward):
"""A custom reward space that is derived from the Runtime observation space."""

def __init__(self):
super().__init__(
name="runtimeseries",
observation_spaces=["Runtime"],
default_value=0,
min=None,
max=None,
default_negates_returns=True,
deterministic=False,
platform_dependent=True,
)
self.last_runtime_observation: List[float] = None

def reset(self, benchmark, observation_view) -> None:
del benchmark # unused
self.last_runtime_observation = observation_view["Runtime"]

def update(
self,
actions: List[ActionType],
observations: List[ObservationType],
observation_view,
) -> float:
del actions # unused
del observation_view # unused
self.last_runtime_observation = observations[0]
return 0


@flaky # runtime may fail
@pytest.mark.parametrize("runtime_observation_count", [1, 3, 5])
def test_correct_number_of_observations_during_reset(
env: LlvmEnv, runtime_observation_count: int
):
env.reward.add_space(RewardDerivedFromRuntime())
env.runtime_observation_count = runtime_observation_count
env.reset(reward_space="runtimeseries")
assert env.runtime_observation_count == runtime_observation_count

# Check that the number of observations that you are receive during reset()
# matches the amount that you asked for.
assert (
len(env.reward.spaces["runtimeseries"].last_runtime_observation)
== runtime_observation_count
)

# Check that the number of observations that you are receive during step()
# matches the amount that you asked for.
env.reward.spaces["runtimeseries"].last_runtime_observation = None
env.step(0)
assert (
len(env.reward.spaces["runtimeseries"].last_runtime_observation)
== runtime_observation_count
)


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ filterwarnings =
error
ignore::pytest.PytestAssertRewriteWarning:
ignore::ResourceWarning:
ignore:SelectableGroups dict interface is deprecated. Use select:DeprecationWarning
# Global timeout applied to all test functions:
timeout = 300