diff --git a/compiler_gym/envs/llvm/llvm_env.py b/compiler_gym/envs/llvm/llvm_env.py index 49b55f075..2a77f43c8 100644 --- a/compiler_gym/envs/llvm/llvm_env.py +++ b/compiler_gym/envs/llvm/llvm_env.py @@ -28,7 +28,7 @@ CostFunctionReward, NormalizedReward, ) -from compiler_gym.errors import BenchmarkInitError +from compiler_gym.errors import BenchmarkInitError, SessionNotFound from compiler_gym.service.client_service_compiler_env import ClientServiceCompilerEnv from compiler_gym.spaces import Box, Commandline from compiler_gym.spaces import Dict as DictSpace @@ -363,7 +363,7 @@ def __init__( def reset(self, *args, **kwargs): try: - observation = super().reset(*args, **kwargs) + return super().reset(*args, **kwargs) except ValueError as e: # Catch and re-raise some known benchmark initialization errors with # a more informative error type. @@ -379,15 +379,6 @@ def reset(self, *args, **kwargs): raise BenchmarkInitError(str(e)) from e raise - # Resend the runtimes-per-observation session parameter, if it is a - # non-default value. - if self._runtimes_per_observation_count is not None: - self.runtime_observation_count = self._runtimes_per_observation_count - if self._runtimes_warmup_per_observation_count is not None: - self.runtime_warmup_runs_count = self._runtimes_warmup_per_observation_count - - return observation - def make_benchmark( self, inputs: Union[ @@ -612,10 +603,12 @@ def runtime_observation_count(self) -> int: @runtime_observation_count.setter def runtime_observation_count(self, n: int) -> None: - if self.in_episode: - self.send_param("llvm.set_runtimes_per_observation_count", str(n)) - # NOTE(cummins): Keep this after the send_param() call because - # send_param() will raise an error if the valid is invalid. + try: + self.send_param( + "llvm.set_runtimes_per_observation_count", str(n), resend_on_reset=True + ) + except SessionNotFound: + pass # Not in session yet, will be sent on reset(). 
self._runtimes_per_observation_count = n @property @@ -648,12 +641,14 @@ def runtime_warmup_runs_count(self) -> int: @runtime_warmup_runs_count.setter def runtime_warmup_runs_count(self, n: int) -> None: - if self.in_episode: + try: self.send_param( - "llvm.set_warmup_runs_count_per_runtime_observation", str(n) + "llvm.set_warmup_runs_count_per_runtime_observation", + str(n), + resend_on_reset=True, ) - # NOTE(cummins): Keep this after the send_param() call because - # send_param() will raise an error if the valid is invalid. + except SessionNotFound: + pass # Not in session yet, will be sent on reset(). self._runtimes_warmup_per_observation_count = n def fork(self): diff --git a/compiler_gym/service/client_service_compiler_env.py b/compiler_gym/service/client_service_compiler_env.py index 6532d92f2..c4cab3cc8 100644 --- a/compiler_gym/service/client_service_compiler_env.py +++ b/compiler_gym/service/client_service_compiler_env.py @@ -202,6 +202,7 @@ def __init__( self._service_endpoint: Union[str, Path] = service self._connection_settings = connection_settings or ConnectionOpts() + self._params_to_send_on_reset: List[SessionParameter] = [] self.service = service_connection or CompilerGymServiceConnection( endpoint=self._service_endpoint, @@ -788,6 +789,12 @@ def _call_with_error( reply.new_action_space ) + # Re-send any session parameters that we marked as needing to be + # re-sent on reset(). Do this before any other initialization as they + # may affect the behavior of subsequent service calls. 
+ if self._params_to_send_on_reset: + self.send_params(*[(p.key, p.value) for p in self._params_to_send_on_reset]) + self.reward.reset(benchmark=self.benchmark, observation_view=self.observation) if self.reward_space: self.episode_reward = 0.0 @@ -1236,7 +1243,7 @@ def validate(self, state: Optional[CompilerEnvState] = None) -> ValidationResult **validation, ) - def send_param(self, key: str, value: str) -> str: + def send_param(self, key: str, value: str, resend_on_reset: bool = False) -> str: """Send a single parameter to the compiler service. See :meth:`send_params() ` @@ -1246,14 +1253,19 @@ def send_param(self, key: str, value: str) -> str: :param value: The parameter value. + :param resend_on_reset: Whether to resend this parameter to the compiler + service on :code:`reset()`. + :return: The response from the compiler service. :raises SessionNotFound: If called before :meth:`reset() `. """ - return self.send_params((key, value))[0] + return self.send_params((key, value), resend_on_reset=resend_on_reset)[0] - def send_params(self, *params: Iterable[Tuple[str, str]]) -> List[str]: + def send_params( + self, *params: Iterable[Tuple[str, str]], resend_on_reset: bool = False + ) -> List[str]: """Send a list of parameters to the compiler service. This provides a mechanism to send messages to the backend compilation @@ -1270,17 +1282,25 @@ def send_params(self, *params: Iterable[Tuple[str, str]]) -> List[str]: :param params: A list of parameters, where each parameter is a :code:`(key, value)` tuple. + :param resend_on_reset: Whether to resend this parameter to the compiler + service on :code:`reset()`. + :return: A list of string responses, one per parameter. :raises SessionNotFound: If called before :meth:`reset() `. 
""" + params_to_send = [SessionParameter(key=k, value=v) for (k, v) in params] + + if resend_on_reset: + self._params_to_send_on_reset += params_to_send + if not self.in_episode: raise SessionNotFound("Must call reset() before send_params()") request = SendSessionParameterRequest( session_id=self._session_id, - parameter=[SessionParameter(key=k, value=v) for (k, v) in params], + parameter=params_to_send, ) reply: SendSessionParameterReply = self.service( self.service.stub.SendSessionParameter, request diff --git a/tests/llvm/BUILD b/tests/llvm/BUILD index d3639abde..1a5182e4d 100644 --- a/tests/llvm/BUILD +++ b/tests/llvm/BUILD @@ -259,6 +259,8 @@ py_test( deps = [ "//compiler_gym/envs/llvm", "//compiler_gym/service:connection", + "//compiler_gym/spaces", + "//compiler_gym/util", "//tests:test_main", "//tests/pytest_plugins:llvm", ], diff --git a/tests/llvm/runtime_test.py b/tests/llvm/runtime_test.py index 2c1384fba..833300359 100644 --- a/tests/llvm/runtime_test.py +++ b/tests/llvm/runtime_test.py @@ -4,12 +4,15 @@ # LICENSE file in the root directory of this source tree. 
"""Integrations tests for LLVM runtime support.""" from pathlib import Path +from typing import List import numpy as np import pytest from flaky import flaky from compiler_gym.envs.llvm import LlvmEnv, llvm_benchmark +from compiler_gym.spaces.reward import Reward +from compiler_gym.util.gym_type_hints import ActionType, ObservationType from tests.test_main import main pytest_plugins = ["tests.pytest_plugins.llvm"] @@ -144,5 +147,64 @@ def test_default_runtime_observation_count_fork(env: LlvmEnv): assert fkd.runtime_warmup_runs_count == wc +class RewardDerivedFromRuntime(Reward): + """A custom reward space that is derived from the Runtime observation space.""" + + def __init__(self): + super().__init__( + name="runtimeseries", + observation_spaces=["Runtime"], + default_value=0, + min=None, + max=None, + default_negates_returns=True, + deterministic=False, + platform_dependent=True, + ) + self.last_runtime_observation: List[float] = None + + def reset(self, benchmark, observation_view) -> None: + del benchmark # unused + self.last_runtime_observation = observation_view["Runtime"] + + def update( + self, + actions: List[ActionType], + observations: List[ObservationType], + observation_view, + ) -> float: + del actions # unused + del observation_view # unused + self.last_runtime_observation = observations[0] + return 0 + + +@flaky # runtime may fail +@pytest.mark.parametrize("runtime_observation_count", [1, 3, 5]) +def test_correct_number_of_observations_during_reset( + env: LlvmEnv, runtime_observation_count: int +): + env.reward.add_space(RewardDerivedFromRuntime()) + env.runtime_observation_count = runtime_observation_count + env.reset(reward_space="runtimeseries") + assert env.runtime_observation_count == runtime_observation_count + + # Check that the number of observations that you receive during reset() + # matches the amount that you asked for. 
+ assert ( + len(env.reward.spaces["runtimeseries"].last_runtime_observation) + == runtime_observation_count + ) + + # Check that the number of observations that you receive during step() + # matches the amount that you asked for. + env.reward.spaces["runtimeseries"].last_runtime_observation = None + env.step(0) + assert ( + len(env.reward.spaces["runtimeseries"].last_runtime_observation) + == runtime_observation_count + ) + + if __name__ == "__main__": main() diff --git a/tox.ini b/tox.ini index 1cfff8d93..7dc8aceb8 100644 --- a/tox.ini +++ b/tox.ini @@ -14,5 +14,6 @@ filterwarnings = error ignore::pytest.PytestAssertRewriteWarning: ignore::ResourceWarning: + ignore:SelectableGroups dict interface is deprecated. Use select:DeprecationWarning # Global timeout applied to all test functions: timeout = 300