-
Notifications
You must be signed in to change notification settings - Fork 54
Feat/38 firestore memory #39
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
e7c3782
0eb1432
0e38fa1
0cb21e7
5a27fb8
120ab5a
a90bf4a
2ff0d6d
8d3244a
8897938
4d10aa5
b51ae6f
8ed069a
af1f3a0
e13f1e5
fbe5c6d
7059163
23b7d7f
62e6183
5088dc8
519bcf0
3fd0ca0
a755bfb
1088007
2c7e8dd
e9465d2
5eff912
4940c4e
8d91b55
eae9de8
de1b5f7
b46711f
a47ca4f
d2dfd70
2f8c931
7f2262e
1d15798
bf0e938
6d23431
6ef25ec
ec903e2
60b1845
05f941d
05e1b53
5fdeba2
560941a
e3a9a18
09b811f
7850f75
2e1e3bc
fa6f7e0
9c22920
bfabcad
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,304 @@ | ||
| # Copyright 2025 Google LLC | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import json | ||
| import logging | ||
| from typing import TYPE_CHECKING, Any, Optional | ||
| from typing_extensions import override | ||
|
|
||
| import google.auth | ||
| from google.cloud import firestore | ||
| from google.genai import types | ||
|
|
||
| from google.adk.agents.llm_agent import Agent | ||
| from google.adk.models.llm_request import LlmRequest | ||
| from google.adk.utils.context_utils import Aclosing | ||
| from google.adk.memory import _utils | ||
| from google.adk.memory.base_memory_service import BaseMemoryService | ||
| from google.adk.memory.base_memory_service import SearchMemoryResponse | ||
| from google.adk.memory.memory_entry import MemoryEntry | ||
|
|
||
| if TYPE_CHECKING: | ||
| from google.adk.sessions.session import Session | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
class FirestoreLLMMemoryService(BaseMemoryService):
    """A Firestore-based memory service that uses an LLM to manage facts.

    Instead of storing raw events, it extracts and reconciles concise facts
    about the user, enabling smarter semantic search and memory management.

    Facts are stored per user under
    ``{collection_name}/{app_name}:{user_id}/facts``; an internal LLM agent
    reconciles new session transcripts against existing facts and filters
    facts at search time.
    """

    # Firestore caps a single write batch at 500 operations.
    _BATCH_LIMIT = 500

    def __init__(
        self,
        collection_name: str = "agent_facts",
        model: str = "gemini-2.0-flash",
        database: Optional[
            str
        ] = "(default)",  # use generous free tier default by default
        fact_limit: int = 100,
    ):
        """Initializes the FirestoreLLMMemoryService.

        Args:
            collection_name: The root collection name in Firestore.
            model: The LLM model to use for memory management.
            database: The Firestore database to use. (Uses free tier by default)
            fact_limit: The maximum number of recent facts to store, search and
                reconcile.
        """
        credentials, project_id = google.auth.default()
        self.db = firestore.AsyncClient(
            credentials=credentials, project=project_id, database=database
        )
        self.collection_name = collection_name

        # Limit for storing, reconciling, and searching facts.
        self._fact_limit = fact_limit

        # The internal agent dedicated to managing the memory state.
        self._memory_agent = Agent(
            model=model,
            name="memory_manager",
            description="Manages user facts and retrieves relevant memories.",
            instruction=(
                "You are a memory management assistant. Your job is to maintain a high-quality "
                "list of facts about conversations. "
                "You will be asked to reconcile new conversations with existing facts "
                "and to retrieve relevant facts based on queries."
            ),
        )

    def _format_session(self, session: Session) -> str:
        """Renders a session's text events as an ``author: text`` transcript."""
        lines = []
        for event in session.events:
            # Skip events with no textual content (e.g. pure function calls).
            if not event.content or not event.content.parts:
                continue
            text = " ".join([p.text for p in event.content.parts if p.text])
            lines.append(f"{event.author}: {text}")
        return "\n".join(lines)

    async def _call_agent(self, prompt: str) -> str:
        """Utility to call the underlying LLM of the agent.

        Args:
            prompt: The user prompt to send.

        Returns:
            The concatenated text of all streamed response parts.
        """
        llm = self._memory_agent.canonical_model
        request = LlmRequest(model=llm.model)

        # Add system instruction from agent.
        if isinstance(self._memory_agent.instruction, str):
            request.append_instructions([self._memory_agent.instruction])

        # Add user prompt.
        request.contents.append(
            types.Content(role="user", parts=[types.Part(text=prompt)])
        )

        async with Aclosing(llm.generate_content_async(request)) as agen:
            response_text_parts = []
            async for response in agen:
                if response.content:
                    for part in response.content.parts:
                        if part.text:
                            response_text_parts.append(part.text)

        return "".join(response_text_parts)

    def _parse_llm_json_response(self, content: str) -> Any | None:
        """Utility to strip markdown and parse JSON from LLM response, with error logging.

        Returns:
            The parsed JSON value, or None if the content is empty or invalid.
        """
        if not content:
            return None
        try:
            content = content.strip()
            # LLMs frequently wrap JSON in a fenced code block; strip it.
            if content.startswith("```json"):
                content = content.removeprefix("```json").removesuffix("```").strip()
            elif content.startswith("```"):
                content = content.removeprefix("```").removesuffix("```").strip()
            return json.loads(content)
        except json.JSONDecodeError as e:
            # Lazy %-style args avoid formatting cost when the level is disabled.
            logger.error(
                "Failed to parse Agent JSON response %s. Response: %s", e, content
            )
            return None

    @staticmethod
    def _as_list(value: Any) -> list:
        """Returns *value* if it is a list, else an empty list.

        Guards against a malformed LLM response: iterating a string here would
        otherwise yield single characters that pass per-item ``isinstance``
        checks and get stored as bogus facts.
        """
        return value if isinstance(value, list) else []

    @override
    async def add_session_to_memory(self, session: Session):
        """Extracts facts from the session and updates Firestore.

        Fetches up to ``fact_limit`` recent facts, asks the memory agent to
        reconcile them with the session transcript, then applies the returned
        add/update/delete operations in batches of at most 500 (Firestore's
        per-batch limit).
        """
        user_key = f"{session.app_name}:{session.user_id}"
        # Reference to the facts subcollection.
        facts_collection_ref = (
            self.db.collection(self.collection_name)
            .document(user_key)
            .collection("facts")
        )
        # Get a subset of existing facts to reconcile against.
        # NOTE(review): the limit means older-but-relevant facts may be missed
        # and duplicated; acceptable for now, revisit if duplicates appear.
        facts_ref = facts_collection_ref.order_by(
            "timestamp", direction=firestore.Query.DESCENDING
        ).limit(self._fact_limit)

        # 1. Fetch existing facts.
        existing_facts = []
        async for doc in facts_ref.stream():
            data = doc.to_dict()
            if data and "text" in data:
                existing_facts.append({"id": doc.id, "text": data["text"]})

        # 2. Reconcile with the Agent.
        session_text = self._format_session(session)
        prompt = (
            f"Existing Facts:\n{json.dumps(existing_facts, indent=2)}\n\n"
            f"New Session Transcript:\n{session_text}\n\n"
            "Task: Reconcile the new session with the existing facts. "
            "Identify facts to add, update, or delete. "
            "Respond ONLY with a JSON object in this format:\n"
            '{"add": ["new fact text"], "update": [{"id": "doc_id", "text": "new text"}], "delete": ["doc_id"]}'
        )

        content = await self._call_agent(prompt)
        operations = self._parse_llm_json_response(content)
        if not operations or not isinstance(operations, dict):
            logger.warning(
                "No valid operations returned from Agent for session %s.", session.id
            )
            return

        # 3. Collect operations, validating each item's shape since the LLM
        #    output is untrusted.
        all_ops = []

        for fact_text in self._as_list(operations.get("add", [])):
            if isinstance(fact_text, str):
                all_ops.append(
                    (
                        "SET",
                        facts_collection_ref.document(),
                        {
                            "text": fact_text,
                            "timestamp": firestore.SERVER_TIMESTAMP,
                            "source_session_id": session.id,
                        },
                    )
                )

        for update in self._as_list(operations.get("update", [])):
            if (
                isinstance(update, dict)
                and isinstance(update.get("id"), str)
                and isinstance(update.get("text"), str)
            ):
                all_ops.append(
                    (
                        "UPDATE",
                        facts_collection_ref.document(update["id"]),
                        {
                            "text": update["text"],
                            "timestamp": firestore.SERVER_TIMESTAMP,
                            "source_session_id": session.id,
                        },
                    )
                )

        for doc_id in self._as_list(operations.get("delete", [])):
            if isinstance(doc_id, str):
                all_ops.append(("DELETE", facts_collection_ref.document(doc_id), None))

        # Commit in chunks of at most _BATCH_LIMIT (Firestore batch limit).
        try:
            for i in range(0, len(all_ops), self._BATCH_LIMIT):
                batch = self.db.batch()
                for op_type, doc_ref, data in all_ops[i : i + self._BATCH_LIMIT]:
                    if op_type == "SET":
                        batch.set(doc_ref, data)
                    elif op_type == "UPDATE":
                        batch.update(doc_ref, data)
                    elif op_type == "DELETE":
                        batch.delete(doc_ref)
                await batch.commit()
        except Exception:
            # Best-effort: log with traceback but do not fail the caller.
            logger.exception("Error applying fact operations to Firestore")

    @override
    async def search_memory(
        self, *, app_name: str, user_id: str, query: str
    ) -> SearchMemoryResponse:
        """Uses the Agent to find relevant facts based on the query.

        Args:
            app_name: The application name.
            user_id: The user ID.
            query: The search query.

        Returns:
            A SearchMemoryResponse with the facts the agent deems relevant.
        """
        user_key = f"{app_name}:{user_id}"
        facts_ref = (
            self.db.collection(self.collection_name)
            .document(user_key)
            .collection("facts")
            .order_by("timestamp", direction=firestore.Query.DESCENDING)
            .limit(self._fact_limit)
        )

        # 1. Fetch facts.
        # Note: If expecting a large number of facts,
        # consider vector search (would require embedding cloud function, etc.)
        all_facts = []
        async for doc in facts_ref.stream():
            data = doc.to_dict()
            if not data or not data.get("text"):
                continue
            timestamp_obj = data.get("timestamp")
            timestamp = timestamp_obj.timestamp() if timestamp_obj else 0

            all_facts.append(
                {
                    "id": doc.id,
                    "text": data.get("text"),
                    "timestamp": timestamp,
                    "source_session_id": data.get("source_session_id"),
                }
            )

        if not all_facts:
            return SearchMemoryResponse()

        # 2. Filter with the Agent.
        prompt = (
            f"User Query: {query}\n\n"
            f"Available Facts:\n{json.dumps(all_facts, indent=2, default=str)}\n\n"
            "Task: Identify which facts are relevant to the user query. "
            "Respond ONLY with a JSON list of IDs of the relevant facts. "
            'Example: ["id1", "id2"]'
        )

        content = await self._call_agent(prompt)
        relevant_ids = self._parse_llm_json_response(content)
        if not relevant_ids or not isinstance(relevant_ids, list):
            return SearchMemoryResponse()

        # 3. Construct response. Build the ID set ONCE (the original rebuilt
        #    it inside the comprehension condition on every iteration).
        relevant_id_set = {str(i) for i in relevant_ids}
        search_response = SearchMemoryResponse()
        for fact in all_facts:
            if fact["id"] not in relevant_id_set:
                continue
            search_response.memories.append(
                MemoryEntry(
                    content=types.Content(parts=[types.Part(text=fact["text"])]),
                    author="memory_manager",
                    timestamp=_utils.format_timestamp(fact["timestamp"]),
                )
            )

        return search_response
Uh oh!
There was an error while loading. Please reload this page.