diff --git a/src/agentlab/agents/generic_agent/reproducibility_agent.py b/src/agentlab/agents/generic_agent/reproducibility_agent.py index b484ac7de..4cab3435d 100644 --- a/src/agentlab/agents/generic_agent/reproducibility_agent.py +++ b/src/agentlab/agents/generic_agent/reproducibility_agent.py @@ -43,7 +43,7 @@ def __init__(self, old_messages, delay=1) -> None: self.old_messages = old_messages self.delay = delay - def invoke(self, messages: list): + def __call__(self, messages: list): self.new_messages = copy(messages) if len(messages) >= len(self.old_messages): diff --git a/src/agentlab/llm/base_api.py b/src/agentlab/llm/base_api.py new file mode 100644 index 000000000..9c1ebf5ff --- /dev/null +++ b/src/agentlab/llm/base_api.py @@ -0,0 +1,33 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass + + +class AbstractChatModel(ABC): + @abstractmethod + def __call__(self, messages: list[dict]) -> dict: + pass + + def get_stats(self): + return {} + + +@dataclass +class BaseModelArgs(ABC): + """Base class for all model arguments.""" + + model_name: str + max_total_tokens: int = None + max_input_tokens: int = None + max_new_tokens: int = None + temperature: float = 0.1 + vision_support: bool = False + + @abstractmethod + def make_model(self) -> AbstractChatModel: + pass + + def prepare_server(self): + pass + + def close_server(self): + pass diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py index 139b2ca5e..a4df0a977 100644 --- a/src/agentlab/llm/chat_api.py +++ b/src/agentlab/llm/chat_api.py @@ -2,14 +2,17 @@ import os import re import time -from abc import ABC, abstractmethod from dataclasses import dataclass +from functools import partial +from typing import Optional import openai +from huggingface_hub import InferenceClient from openai import AzureOpenAI, OpenAI import agentlab.llm.tracking as tracking -from agentlab.llm.huggingface_utils import HuggingFaceURLChatModel +from agentlab.llm.base_api import AbstractChatModel, BaseModelArgs +from agentlab.llm.huggingface_utils import HFBaseChatModel def make_system_message(content: str) -> dict: @@ -24,10 +27,10 @@ def make_assistant_message(content: str) -> dict: return dict(role="assistant", content=content) -class CheatMiniWoBLLM: +class CheatMiniWoBLLM(AbstractChatModel): """For unit-testing purposes only. It only work with miniwob.click-test task.""" - def invoke(self, messages) -> str: + def __call__(self, messages) -> str: prompt = messages[-1]["content"] match = re.search(r"^\s*\[(\d+)\].*button", prompt, re.MULTILINE | re.IGNORECASE) @@ -44,12 +47,6 @@ def invoke(self, messages) -> str: """ return make_assistant_message(answer) - def __call__(self, messages) -> str: - return self.invoke(messages) - - def get_stats(self): - return {} - @dataclass class CheatMiniWoBLLMArgs: @@ -68,28 +65,6 @@ def close_server(self): pass -@dataclass -class BaseModelArgs(ABC): - """Base class for all model arguments.""" - - model_name: str - max_total_tokens: int = None - max_input_tokens: int = None - max_new_tokens: int = None - temperature: float = 0.1 - vision_support: bool = False - - @abstractmethod - def make_model(self) -> "ChatModel": - pass - - def prepare_server(self): - pass - - def close_server(self): - pass - - @dataclass class OpenRouterModelArgs(BaseModelArgs): """Serializable object for instantiating a generic chat model with an OpenAI @@ -221,7 +196,7 @@ def handle_error(error, itr, min_retry_wait_time, max_retry): return error_type -class ChatModel: +class ChatModel(AbstractChatModel): def __init__( self, model_name, @@ -310,9 +285,6 @@ def __call__(self, messages: list[dict]) -> dict: return make_assistant_message(completion.choices[0].message.content) - def invoke(self, messages: list[dict]) -> dict: - return self(messages) - def get_stats(self): return { "n_retry_llm": self.retries, @@ -401,3 +373,26 @@ def __init__( client_args=client_args, pricing_func=tracking.get_pricing_openai, ) + + +class HuggingFaceURLChatModel(HFBaseChatModel): + def __init__( + self, + model_name: str, + model_url: str, + token: Optional[str] = None, + temperature: Optional[int] = 1e-1, + max_new_tokens: Optional[int] = 512, + n_retry_server: Optional[int] = 4, + ): + super().__init__(model_name, n_retry_server) + if temperature < 1e-3: + logging.warning("Models might behave weirdly when temperature is too low.") + + if token is None: + token = os.environ["TGI_TOKEN"] + + client = InferenceClient(model=model_url, token=token) + self.llm = partial( + client.text_generation, temperature=temperature, max_new_tokens=max_new_tokens + ) diff --git a/src/agentlab/llm/huggingface_utils.py b/src/agentlab/llm/huggingface_utils.py index ce4dae067..470324bd6 100644 --- a/src/agentlab/llm/huggingface_utils.py +++ b/src/agentlab/llm/huggingface_utils.py @@ -1,17 +1,15 @@ import logging -import os import time -from functools import partial from typing import Any, List, Optional -from huggingface_hub import InferenceClient from pydantic import Field from transformers import AutoTokenizer, GPT2TokenizerFast +from agentlab.llm.base_api import AbstractChatModel from agentlab.llm.prompt_templates import PromptTemplate, get_prompt_template -class HFBaseChatModel: +class HFBaseChatModel(AbstractChatModel): """ Custom LLM Chatbot that can interface with HuggingFace models. @@ -94,101 +92,6 @@ def __call__( def _llm_type(self): return "huggingface" - def invoke(self, messages: list[dict]) -> dict: - return self(messages) - - def get_stats(self): - return {} - - -class HuggingFaceURLChatModel(HFBaseChatModel): - def __init__( - self, - model_name: str, - model_url: str, - token: Optional[str] = None, - temperature: Optional[int] = 1e-1, - max_new_tokens: Optional[int] = 512, - n_retry_server: Optional[int] = 4, - ): - super().__init__(model_name, n_retry_server) - if temperature < 1e-3: - logging.warning("Models might behave weirdly when temperature is too low.") - - if token is None: - token = os.environ["TGI_TOKEN"] - - client = InferenceClient(model=model_url, token=token) - self.llm = partial( - client.text_generation, temperature=temperature, max_new_tokens=max_new_tokens - ) - - -# def _convert_messages_to_dict(messages, column_remap={}): -# """ -# Converts a list of message objects into a list of dictionaries, categorizing each message by its role. - -# Each message is expected to be an instance of one of the following types: SystemMessage, HumanMessage, AIMessage. -# The function maps each message to its corresponding role ('system', 'user', 'assistant') and formats it into a dictionary. - -# Args: -# messages (list): A list of message objects. -# column_remap (dict): A dictionary that maps the column names to the desired output format. - -# Returns: -# list: A list of dictionaries where each dictionary represents a message and contains 'role' and 'content' keys. - -# Raises: -# ValueError: If an unsupported message type is encountered. - -# Example: -# >>> messages = [SystemMessage("System initializing..."), HumanMessage("Hello!"), AIMessage("How can I assist?")] -# >>> _convert_messages_to_dict(messages) -# [ -# {"role": "system", "content": "System initializing..."}, -# {"role": "user", "content": "Hello!"}, -# {"role": "assistant", "content": "How can I assist?"} -# ] -# """ - -# human_key = column_remap.get("HumanMessage", "user") -# ai_message_key = column_remap.get("AIMessage", "assistant") -# role_key = column_remap.get("role", "role") -# text_key = column_remap.get("text", "content") -# image_key = column_remap.get("image", "media_url") - -# # Mapping of message types to roles -# message_type_to_role = { -# SystemMessage: "system", -# HumanMessage: human_key, -# AIMessage: ai_message_key, -# } - -# def convert_format_vision(message_content, role, text_key, image_key): -# result = {} -# result["type"] = role -# for item in message_content: -# if item["type"] == "text": -# result[text_key] = item["text"] -# elif item["type"] == "image_url": -# result[image_key] = item["image_url"] -# return result - -# chat = [] -# for message in messages: -# message_role = message_type_to_role.get(type(message)) -# if message_role: -# if isinstance(message.content, str): -# chat.append({role_key: message_role, text_key: message.content}) -# else: -# chat.append( -# convert_format_vision(message.content, message_role, text_key, image_key) -# ) -# else: -# raise ValueError(f"Message type {type(message)} not supported") - -# return chat - def _prepend_system_to_first_user(messages, column_remap={}): # Initialize an index for the system message diff --git a/src/agentlab/llm/llm_utils.py b/src/agentlab/llm/llm_utils.py index 4b876b54f..c3d750098 100644 --- a/src/agentlab/llm/llm_utils.py +++ b/src/agentlab/llm/llm_utils.py @@ -79,7 +79,7 @@ def retry( """ tries = 0 while tries < n_retry: - answer = chat.invoke(messages) + answer = chat(messages) messages.append(answer) # TODO: could we change this to not use inplace modifications ? try: diff --git a/tests/agents/test_agent.py b/tests/agents/test_agent.py index ae1732892..0b2c31f28 100644 --- a/tests/agents/test_agent.py +++ b/tests/agents/test_agent.py @@ -50,7 +50,7 @@ class CheatMiniWoBLLM_ParseRetry: n_retry: int retry_count: int = 0 - def invoke(self, messages) -> str: + def __call__(self, messages) -> str: if self.retry_count < self.n_retry: self.retry_count += 1 return dict(role="assistant", content="I'm retrying") @@ -71,9 +71,6 @@ def invoke(self, messages) -> str: """ return dict(role="assistant", content=answer) - def __call__(self, messages) -> str: - return self.invoke(messages) - def get_stats(self): return {} @@ -94,7 +91,7 @@ class CheatLLM_LLMError: n_retry: int = 0 success: bool = False - def invoke(self, messages) -> str: + def __call__(self, messages) -> str: if self.success: prompt = messages[1].get("content", "") match = re.search(r"^\s*\[(\d+)\].*button", prompt, re.MULTILINE | re.IGNORECASE) @@ -113,9 +110,6 @@ def invoke(self, messages) -> str: return dict(role="assistant", content=answer) raise OpenAIError("LLM failed to respond") - def __call__(self, messages) -> str: - return self.invoke(messages) - def get_stats(self): return {"n_llm_retry": self.n_retry, "n_llm_busted_retry": int(not self.success)} diff --git a/tests/llm/test_chat_api.py b/tests/llm/test_chat_api.py index b49f35887..f06fa7fa4 100644 --- a/tests/llm/test_chat_api.py +++ b/tests/llm/test_chat_api.py @@ -35,7 +35,7 @@ def test_api_model_args_azure(): make_system_message("You are an helpful virtual assistant"), make_user_message("Give the third prime number"), ] - answer = model.invoke(messages) + answer = model(messages) assert "5" in answer.get("content") @@ -56,6 +56,6 @@ def test_api_model_args_openai(): make_system_message("You are an helpful virtual assistant"), make_user_message("Give the third prime number"), ] - answer = model.invoke(messages) + answer = model(messages) assert "5" in answer.get("content") diff --git a/tests/llm/test_llm_utils.py b/tests/llm/test_llm_utils.py index 1314bea03..7e5bb87cc 100644 --- a/tests/llm/test_llm_utils.py +++ b/tests/llm/test_llm_utils.py @@ -93,9 +93,12 @@ def test_compress_string(): # Mock ChatOpenAI class class MockChatOpenAI: - def invoke(self, messages): + def call(self, messages): return "mocked response" + def __call__(self, messages): + return self.call(messages) + def mock_parser(answer): if answer == "correct content": @@ -126,7 +129,7 @@ def mock_rate_limit_error(message: str, status_code: Literal[429] = 429) -> Rate # Test to ensure function stops retrying after reaching the max wait time # def test_rate_limit_max_wait_time(): # mock_chat = MockChatOpenAI() -# mock_chat.invoke = Mock( +# mock_chat.call = Mock( # side_effect=mock_rate_limit_error("Rate limit reached. Please try again in 2s.") # ) @@ -141,12 +144,12 @@ def mock_rate_limit_error(message: str, status_code: Literal[429] = 429) -> Rate # ) # # The function should stop retrying after 2 attempts (6s each time, 12s total which is greater than the 10s max wait time) -# assert mock_chat.invoke.call_count == 3 +# assert mock_chat.call.call_count == 3 # def test_rate_limit_success(): # mock_chat = MockChatOpenAI() -# mock_chat.invoke = Mock( +# mock_chat.call = Mock( # side_effect=[ # mock_rate_limit_error("Rate limit reached. Please try again in 2s."), # make_system_message("correct content"), @@ -163,7 +166,7 @@ def mock_rate_limit_error(message: str, status_code: Literal[429] = 429) -> Rate # ) # assert result == "Parsed value" -# assert mock_chat.invoke.call_count == 2 +# assert mock_chat.call.call_count == 2 # Mock a successful parser response to test function exit before max retries @@ -172,7 +175,7 @@ def test_successful_parse_before_max_retries(): # mock a chat that returns the wrong content the first 2 time, but the right # content on the 3rd time - mock_chat.invoke = Mock( + mock_chat.call = Mock( side_effect=[ make_system_message("wrong content"), make_system_message("wrong content"), @@ -183,7 +186,7 @@ def test_successful_parse_before_max_retries(): result = llm_utils.retry(mock_chat, [], 5, mock_parser) assert result == "Parsed value" - assert mock_chat.invoke.call_count == 3 + assert mock_chat.call.call_count == 3 def test_unsuccessful_parse_before_max_retries(): @@ -191,7 +194,7 @@ def test_unsuccessful_parse_before_max_retries(): # mock a chat that returns the wrong content the first 2 time, but the right # content on the 3rd time - mock_chat.invoke = Mock( + mock_chat.call = Mock( side_effect=[ make_system_message("wrong content"), make_system_message("wrong content"), @@ -201,12 +204,12 @@ def test_unsuccessful_parse_before_max_retries(): with pytest.raises(llm_utils.ParseError): result = llm_utils.retry(mock_chat, [], 2, mock_parser) - assert mock_chat.invoke.call_count == 2 + assert mock_chat.call.call_count == 2 def test_retry_parse_raises(): mock_chat = MockChatOpenAI() - mock_chat.invoke = Mock(return_value=make_system_message("mocked response")) + mock_chat.call = Mock(return_value=make_system_message("mocked response")) parser_raises = Mock(side_effect=ValueError("Parser error")) with pytest.raises(ValueError): diff --git a/tests/llm/test_tracking.py b/tests/llm/test_tracking.py index cc5abd36f..01ebcc067 100644 --- a/tests/llm/test_tracking.py +++ b/tests/llm/test_tracking.py @@ -136,7 +136,7 @@ def test_openai_chat_model(): make_user_message("Give the third prime number"), ] with tracking.set_tracker() as tracker: - answer = chat_model.invoke(messages) + answer = chat_model(messages) assert "5" in answer.get("content") assert tracker.stats["cost"] > 0 @@ -161,7 +161,7 @@ def test_azure_chat_model(): make_user_message("Give the third prime number"), ] with tracking.set_tracker() as tracker: - answer = chat_model.invoke(messages) + answer = chat_model(messages) assert "5" in answer.get("content") assert tracker.stats["cost"] > 0 @@ -178,6 +178,6 @@ def test_openrouter_chat_model(): make_user_message("Give the third prime number"), ] with tracking.set_tracker() as tracker: - answer = chat_model.invoke(messages) + answer = chat_model(messages) assert "5" in answer.get("content") assert tracker.stats["cost"] > 0