diff --git a/src/agentlab/agents/dynamic_prompting.py b/src/agentlab/agents/dynamic_prompting.py index 54d52f2cd..5abd7ea60 100644 --- a/src/agentlab/agents/dynamic_prompting.py +++ b/src/agentlab/agents/dynamic_prompting.py @@ -366,16 +366,50 @@ def __init__(self, bid, visible: bool = True, prefix="") -> None: """ +class Tabs(PromptElement): + def __init__(self, obs, visible: bool = True, prefix="") -> None: + super().__init__(visible=visible) + self.obs = obs + self.prefix = prefix + + @property + def _prompt(self) -> str: + # by implementing this as a property, it's only computed if visible + prompt_pieces = [f"\n{self.prefix}Currently open tabs:"] + for page_index, (page_url, page_title) in enumerate( + zip(self.obs["open_pages_urls"], self.obs["open_pages_titles"]) + ): + active_or_not = " (active tab)" if page_index == self.obs["active_page_index"] else "" + prompt_piece = f"""\ +Tab {page_index}{active_or_not}: + Title: {page_title} + URL: {page_url} +""" + prompt_pieces.append(prompt_piece) + return "\n".join(prompt_pieces) + + +def has_tab_action(action_set: bgym.HighLevelActionSetArgs): + return "tab" in action_set.subsets + + class Observation(Shrinkable): """Observation of the current step. Contains the html, the accessibility tree and the error logs.
""" - def __init__(self, obs, flags: ObsFlags) -> None: + def __init__(self, obs, flags: ObsFlags, use_tabs=False) -> None: super().__init__() self.flags = flags self.obs = obs + + self.tabs = Tabs( + obs, + visible=use_tabs, + prefix="## ", + ) + self.html = HTML( obs[flags.html_type], visible_elements_only=flags.filter_visible_elements_only, @@ -409,7 +443,7 @@ def shrink(self): def _prompt(self) -> str: return f""" # Observation of current step: -{self.html.prompt}{self.ax_tree.prompt}{self.focused_element.prompt}{self.error.prompt} +{self.tabs.prompt}{self.html.prompt}{self.ax_tree.prompt}{self.focused_element.prompt}{self.error.prompt} """ diff --git a/src/agentlab/agents/generic_agent/generic_agent_prompt.py b/src/agentlab/agents/generic_agent/generic_agent_prompt.py index a655b42f3..50eeeed21 100644 --- a/src/agentlab/agents/generic_agent/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent/generic_agent_prompt.py @@ -74,7 +74,11 @@ def __init__( obs_history[-1]["goal"], extra_instructions=flags.extra_instructions ) - self.obs = dp.Observation(obs_history[-1], self.flags.obs) + self.obs = dp.Observation( + obs_history[-1], + self.flags.obs, + use_tabs=dp.has_tab_action(self.flags.action.action_set), + ) self.action_prompt = dp.ActionPrompt(action_set, action_flags=flags.action) diff --git a/tests/experiments/test_reproducibility_util.py b/tests/experiments/test_reproducibility_util.py index 57299ae3e..6008bb30e 100644 --- a/tests/experiments/test_reproducibility_util.py +++ b/tests/experiments/test_reproducibility_util.py @@ -11,7 +11,7 @@ @pytest.mark.parametrize( "benchmark_name", - ["miniwob_all", "workarena_l1", "webarena", "visualwebarena"], + ["miniwob", "workarena_l1", "webarena", "visualwebarena"], ) def test_get_reproducibility_info(benchmark_name):