Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/agentlab/agents/generic_agent/reproducibility_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(self, old_messages, delay=1) -> None:
self.old_messages = old_messages
self.delay = delay

def invoke(self, messages: list):
def __call__(self, messages: list):
self.new_messages = copy(messages)

if len(messages) >= len(self.old_messages):
Expand Down
33 changes: 33 additions & 0 deletions src/agentlab/llm/base_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass


class AbstractChatModel(ABC):
    """Interface for chat models that are invoked by calling the instance.

    Subclasses must provide ``__call__``; ``get_stats`` may be overridden
    and returns an empty mapping by default.
    """

    @abstractmethod
    def __call__(self, messages: list[dict]) -> dict:
        """Produce a reply for ``messages``; implemented by concrete subclasses."""

    def get_stats(self):
        """Return statistics about the model; the base class reports none."""
        return dict()


@dataclass
class BaseModelArgs(ABC):
    """Base class for all model arguments.

    Holds the shared configuration fields and declares the hooks every
    concrete argument class exposes. ``make_model`` must be implemented by
    subclasses; the two server hooks do nothing unless overridden.
    """

    # Required: no default is provided for the model name.
    model_name: str
    # Token limits default to None when not supplied.
    max_total_tokens: int = None
    max_input_tokens: int = None
    max_new_tokens: int = None
    temperature: float = 0.1
    vision_support: bool = False

    @abstractmethod
    def make_model(self) -> AbstractChatModel:
        """Build and return the chat model; implemented by subclasses."""

    def prepare_server(self):
        """No-op by default; subclasses may override."""
        return None

    def close_server(self):
        """No-op by default; subclasses may override."""
        return None
67 changes: 31 additions & 36 deletions src/agentlab/llm/chat_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@
import os
import re
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import partial
from typing import Optional

import openai
from huggingface_hub import InferenceClient
from openai import AzureOpenAI, OpenAI

import agentlab.llm.tracking as tracking
from agentlab.llm.huggingface_utils import HuggingFaceURLChatModel
from agentlab.llm.base_api import AbstractChatModel, BaseModelArgs
from agentlab.llm.huggingface_utils import HFBaseChatModel


def make_system_message(content: str) -> dict:
Expand All @@ -24,10 +27,10 @@ def make_assistant_message(content: str) -> dict:
return dict(role="assistant", content=content)


class CheatMiniWoBLLM:
class CheatMiniWoBLLM(AbstractChatModel):
"""For unit-testing purposes only. It only work with miniwob.click-test task."""

def invoke(self, messages) -> str:
def __call__(self, messages) -> str:
prompt = messages[-1]["content"]
match = re.search(r"^\s*\[(\d+)\].*button", prompt, re.MULTILINE | re.IGNORECASE)

Expand All @@ -44,12 +47,6 @@ def invoke(self, messages) -> str:
"""
return make_assistant_message(answer)

def __call__(self, messages) -> str:
return self.invoke(messages)

def get_stats(self):
return {}


@dataclass
class CheatMiniWoBLLMArgs:
Expand All @@ -68,28 +65,6 @@ def close_server(self):
pass


@dataclass
class BaseModelArgs(ABC):
"""Base class for all model arguments."""

model_name: str
max_total_tokens: int = None
max_input_tokens: int = None
max_new_tokens: int = None
temperature: float = 0.1
vision_support: bool = False

@abstractmethod
def make_model(self) -> "ChatModel":
pass

def prepare_server(self):
pass

def close_server(self):
pass


@dataclass
class OpenRouterModelArgs(BaseModelArgs):
"""Serializable object for instantiating a generic chat model with an OpenAI
Expand Down Expand Up @@ -221,7 +196,7 @@ def handle_error(error, itr, min_retry_wait_time, max_retry):
return error_type


class ChatModel:
class ChatModel(AbstractChatModel):
def __init__(
self,
model_name,
Expand Down Expand Up @@ -310,9 +285,6 @@ def __call__(self, messages: list[dict]) -> dict:

return make_assistant_message(completion.choices[0].message.content)

def invoke(self, messages: list[dict]) -> dict:
return self(messages)

def get_stats(self):
return {
"n_retry_llm": self.retries,
Expand Down Expand Up @@ -401,3 +373,26 @@ def __init__(
client_args=client_args,
pricing_func=tracking.get_pricing_openai,
)


class HuggingFaceURLChatModel(HFBaseChatModel):
    """Chat model whose generations come from an ``InferenceClient`` pointed at a URL."""

    def __init__(
        self,
        model_name: str,
        model_url: str,
        token: Optional[str] = None,
        temperature: Optional[float] = 1e-1,
        max_new_tokens: Optional[int] = 512,
        n_retry_server: Optional[int] = 4,
    ):
        """Set up the inference client and bind the generation parameters.

        Args:
            model_name: Forwarded to ``HFBaseChatModel.__init__``.
            model_url: Passed as ``model`` to ``InferenceClient``.
            token: Access token for the client. When ``None``, the value of
                the ``TGI_TOKEN`` environment variable is used.
            temperature: Sampling temperature bound into the generation
                call. ``None`` is forwarded unchanged to the client.
            max_new_tokens: Bound into the generation call.
            n_retry_server: Forwarded to ``HFBaseChatModel.__init__``.

        Raises:
            KeyError: If ``token`` is ``None`` and ``TGI_TOKEN`` is not set.
        """
        super().__init__(model_name, n_retry_server)
        # Guard against None: the parameter is declared Optional, and comparing
        # None with a float would raise TypeError.
        if temperature is not None and temperature < 1e-3:
            logging.warning("Models might behave weirdly when temperature is too low.")

        if token is None:
            token = os.environ["TGI_TOKEN"]

        client = InferenceClient(model=model_url, token=token)
        # Pre-bind the sampling parameters so self.llm only needs the prompt.
        self.llm = partial(
            client.text_generation, temperature=temperature, max_new_tokens=max_new_tokens
        )
101 changes: 2 additions & 99 deletions src/agentlab/llm/huggingface_utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import logging
import os
import time
from functools import partial
from typing import Any, List, Optional

from huggingface_hub import InferenceClient
from pydantic import Field
from transformers import AutoTokenizer, GPT2TokenizerFast

from agentlab.llm.base_api import AbstractChatModel
from agentlab.llm.prompt_templates import PromptTemplate, get_prompt_template


class HFBaseChatModel:
class HFBaseChatModel(AbstractChatModel):
"""
Custom LLM Chatbot that can interface with HuggingFace models.

Expand Down Expand Up @@ -94,101 +92,6 @@ def __call__(
def _llm_type(self):
return "huggingface"

def invoke(self, messages: list[dict]) -> dict:
return self(messages)

def get_stats(self):
return {}


class HuggingFaceURLChatModel(HFBaseChatModel):
    """Chat model whose generations come from an ``InferenceClient`` pointed at a URL."""

    def __init__(
        self,
        model_name: str,
        model_url: str,
        token: Optional[str] = None,
        temperature: Optional[float] = 1e-1,
        max_new_tokens: Optional[int] = 512,
        n_retry_server: Optional[int] = 4,
    ):
        """Set up the inference client and bind the generation parameters.

        Args:
            model_name: Forwarded to ``HFBaseChatModel.__init__``.
            model_url: Passed as ``model`` to ``InferenceClient``.
            token: Access token for the client. When ``None``, the value of
                the ``TGI_TOKEN`` environment variable is used.
            temperature: Sampling temperature bound into the generation
                call. ``None`` is forwarded unchanged to the client.
            max_new_tokens: Bound into the generation call.
            n_retry_server: Forwarded to ``HFBaseChatModel.__init__``.

        Raises:
            KeyError: If ``token`` is ``None`` and ``TGI_TOKEN`` is not set.
        """
        super().__init__(model_name, n_retry_server)
        # Guard against None: the parameter is declared Optional, and comparing
        # None with a float would raise TypeError.
        if temperature is not None and temperature < 1e-3:
            logging.warning("Models might behave weirdly when temperature is too low.")

        if token is None:
            token = os.environ["TGI_TOKEN"]

        client = InferenceClient(model=model_url, token=token)
        # Pre-bind the sampling parameters so self.llm only needs the prompt.
        self.llm = partial(
            client.text_generation, temperature=temperature, max_new_tokens=max_new_tokens
        )


# def _convert_messages_to_dict(messages, column_remap={}):
# """
# Converts a list of message objects into a list of dictionaries, categorizing each message by its role.

# Each message is expected to be an instance of one of the following types: SystemMessage, HumanMessage, AIMessage.
# The function maps each message to its corresponding role ('system', 'user', 'assistant') and formats it into a dictionary.

# Args:
# messages (list): A list of message objects.
# column_remap (dict): A dictionary that maps the column names to the desired output format.

# Returns:
# list: A list of dictionaries where each dictionary represents a message and contains 'role' and 'content' keys.

# Raises:
# ValueError: If an unsupported message type is encountered.

# Example:
# >>> messages = [SystemMessage("System initializing..."), HumanMessage("Hello!"), AIMessage("How can I assist?")]
# >>> _convert_messages_to_dict(messages)
# [
# {"role": "system", "content": "System initializing..."},
# {"role": "user", "content": "Hello!"},
# {"role": "assistant", "content": "How can I assist?"}
# ]
# """

# human_key = column_remap.get("HumanMessage", "user")
# ai_message_key = column_remap.get("AIMessage", "assistant")
# role_key = column_remap.get("role", "role")
# text_key = column_remap.get("text", "content")
# image_key = column_remap.get("image", "media_url")

# # Mapping of message types to roles
# message_type_to_role = {
# SystemMessage: "system",
# HumanMessage: human_key,
# AIMessage: ai_message_key,
# }

# def convert_format_vision(message_content, role, text_key, image_key):
# result = {}
# result["type"] = role
# for item in message_content:
# if item["type"] == "text":
# result[text_key] = item["text"]
# elif item["type"] == "image_url":
# result[image_key] = item["image_url"]
# return result

# chat = []
# for message in messages:
# message_role = message_type_to_role.get(type(message))
# if message_role:
# if isinstance(message.content, str):
# chat.append({role_key: message_role, text_key: message.content})
# else:
# chat.append(
# convert_format_vision(message.content, message_role, text_key, image_key)
# )
# else:
# raise ValueError(f"Message type {type(message)} not supported")

# return chat


def _prepend_system_to_first_user(messages, column_remap={}):
# Initialize an index for the system message
Expand Down
2 changes: 1 addition & 1 deletion src/agentlab/llm/llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def retry(
"""
tries = 0
while tries < n_retry:
answer = chat.invoke(messages)
answer = chat(messages)
messages.append(answer) # TODO: could we change this to not use inplace modifications ?

try:
Expand Down
10 changes: 2 additions & 8 deletions tests/agents/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class CheatMiniWoBLLM_ParseRetry:
n_retry: int
retry_count: int = 0

def invoke(self, messages) -> str:
def __call__(self, messages) -> str:
if self.retry_count < self.n_retry:
self.retry_count += 1
return dict(role="assistant", content="I'm retrying")
Expand All @@ -71,9 +71,6 @@ def invoke(self, messages) -> str:
"""
return dict(role="assistant", content=answer)

def __call__(self, messages) -> str:
return self.invoke(messages)

def get_stats(self):
return {}

Expand All @@ -94,7 +91,7 @@ class CheatLLM_LLMError:
n_retry: int = 0
success: bool = False

def invoke(self, messages) -> str:
def __call__(self, messages) -> str:
if self.success:
prompt = messages[1].get("content", "")
match = re.search(r"^\s*\[(\d+)\].*button", prompt, re.MULTILINE | re.IGNORECASE)
Expand All @@ -113,9 +110,6 @@ def invoke(self, messages) -> str:
return dict(role="assistant", content=answer)
raise OpenAIError("LLM failed to respond")

def __call__(self, messages) -> str:
return self.invoke(messages)

def get_stats(self):
return {"n_llm_retry": self.n_retry, "n_llm_busted_retry": int(not self.success)}

Expand Down
4 changes: 2 additions & 2 deletions tests/llm/test_chat_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_api_model_args_azure():
make_system_message("You are an helpful virtual assistant"),
make_user_message("Give the third prime number"),
]
answer = model.invoke(messages)
answer = model(messages)

assert "5" in answer.get("content")

Expand All @@ -56,6 +56,6 @@ def test_api_model_args_openai():
make_system_message("You are an helpful virtual assistant"),
make_user_message("Give the third prime number"),
]
answer = model.invoke(messages)
answer = model(messages)

assert "5" in answer.get("content")
Loading