From 352cc8017534264dec5b3a43536ecc53defc5d97 Mon Sep 17 00:00:00 2001 From: Abhijeetsingh Meena Date: Thu, 17 Apr 2025 22:30:06 +0530 Subject: [PATCH 01/10] Add `model_context` parameter to `SelectorGroupChat` for dynamic speaker selection (#6301) Signed-off-by: Abhijeetsingh Meena --- .../teams/_group_chat/_selector_group_chat.py | 57 +++++++++++++++---- 1 file changed, 45 insertions(+), 12 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index 2a7b15889ec3..bf43294bbd27 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -4,8 +4,18 @@ from inspect import iscoroutinefunction from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Sequence, Union, cast -from autogen_core import AgentRuntime, Component, ComponentModel -from autogen_core.models import AssistantMessage, ChatCompletionClient, ModelFamily, SystemMessage, UserMessage +from autogen_core import AgentRuntime, CancellationToken, Component, ComponentModel +from autogen_core.model_context import ( + ChatCompletionContext, + UnboundedChatCompletionContext, +) +from autogen_core.models import ( + AssistantMessage, + ChatCompletionClient, + ModelFamily, + SystemMessage, + UserMessage, +) from pydantic import BaseModel from typing_extensions import Self @@ -56,6 +66,7 @@ def __init__( max_selector_attempts: int, candidate_func: Optional[CandidateFuncType], emit_team_events: bool, + model_context: ChatCompletionContext | None, ) -> None: super().__init__( name, @@ -79,6 +90,11 @@ def __init__( self._max_selector_attempts = max_selector_attempts self._candidate_func = candidate_func self._is_candidate_func_async = iscoroutinefunction(self._candidate_func) + if model_context is not None: + self._model_context = model_context + else: + self._model_context = UnboundedChatCompletionContext() + self._cancellation_token = CancellationToken() async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None: pass @@ -153,16 +169,7 @@ async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) - assert len(participants) > 0 # Construct the history of the conversation. - history_messages: List[str] = [] - for msg in thread: - if not isinstance(msg, BaseChatMessage): - # Only process chat messages. - continue - message = f"{msg.source}: {msg.to_model_text()}" - history_messages.append( - message.rstrip() + "\n\n" - ) # Create some consistency for how messages are separated in the transcript - history = "\n".join(history_messages) + history = self.construct_message_history(thread) # Construct agent roles. # Each agent sould appear on a single line. @@ -180,10 +187,33 @@ async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) - trace_logger.debug(f"Selected speaker: {agent_name}") return agent_name + def construct_message_history( + self, message_history: Sequence[Union[BaseChatMessage, BaseAgentEvent, UserMessage, AssistantMessage]] + ) -> str: + # Construct the history of the conversation. 
+ history_messages: List[str] = [] + for msg in message_history: + if isinstance(msg, BaseChatMessage): + message = f"{msg.source}: {msg.to_model_text()}" + history_messages.append(message.rstrip() + "\n\n") + elif isinstance(msg, UserMessage) or isinstance(msg, AssistantMessage): + message = f"{msg.source}: {msg.content}" + history_messages.append( + message.rstrip() + "\n\n" + ) # Create some consistency for how messages are separated in the transcript + + history: str = "\n".join(history_messages) + return history + async def _select_speaker(self, roles: str, participants: List[str], history: str, max_attempts: int) -> str: + model_context_messages = await self._model_context.get_messages() + model_context_history = self.construct_message_history(model_context_messages) # type: ignore + history = model_context_history + history + select_speaker_prompt = self._selector_prompt.format( roles=roles, participants=str(participants), history=history ) + select_speaker_messages: List[SystemMessage | UserMessage | AssistantMessage] if ModelFamily.is_openai(self._model_client.model_info["family"]): select_speaker_messages = [SystemMessage(content=select_speaker_prompt)] @@ -453,6 +483,7 @@ def __init__( candidate_func: Optional[CandidateFuncType] = None, custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None, emit_team_events: bool = False, + model_context: ChatCompletionContext | None = None, ): super().__init__( participants, @@ -473,6 +504,7 @@ def __init__( self._selector_func = selector_func self._max_selector_attempts = max_selector_attempts self._candidate_func = candidate_func + self._model_context = model_context def _create_group_chat_manager_factory( self, @@ -505,6 +537,7 @@ def _create_group_chat_manager_factory( self._max_selector_attempts, self._candidate_func, self._emit_team_events, + self._model_context, ) def _to_config(self) -> SelectorGroupChatConfig: From d839ec38014cbd719f10c080f13b2805846a71ae Mon Sep 17 00:00:00 2001 From: Abhijeetsingh Meena Date: Mon, 21 Apr 2025 21:52:46 +0530 Subject: [PATCH 02/10] Use `model_context` to select next speaker --- .../teams/_group_chat/_selector_group_chat.py | 38 ++++++++++++++----- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index bf43294bbd27..57e4bb8421a7 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -6,8 +6,8 @@ from autogen_core import AgentRuntime, CancellationToken, Component, ComponentModel from autogen_core.model_context import ( + BufferedChatCompletionContext, ChatCompletionContext, - UnboundedChatCompletionContext, ) from autogen_core.models import ( AssistantMessage, @@ -25,6 +25,7 @@ from ...messages import ( BaseAgentEvent, BaseChatMessage, + HandoffMessage, MessageFactory, ) from ...state import SelectorManagerState @@ -93,7 +94,8 @@ def __init__( if model_context is not None: self._model_context = model_context else: - self._model_context = UnboundedChatCompletionContext() + # TODO: finalize the best default context class + self._model_context = BufferedChatCompletionContext(buffer_size=5) self._cancellation_token = CancellationToken() async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> 
None: @@ -102,6 +104,7 @@ async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> async def reset(self) -> None: self._current_turn = 0 self._message_thread.clear() + await self._model_context.clear() if self._termination_condition is not None: await self._termination_condition.reset() self._previous_speaker = None @@ -117,15 +120,36 @@ async def save_state(self) -> Mapping[str, Any]: async def load_state(self, state: Mapping[str, Any]) -> None: selector_state = SelectorManagerState.model_validate(state) self._message_thread = [self._message_factory.create(msg) for msg in selector_state.message_thread] + await self._add_messages_to_context( + self._model_context, [msg for msg in self._message_thread if isinstance(msg, BaseChatMessage)] + ) self._current_turn = selector_state.current_turn self._previous_speaker = selector_state.previous_speaker + @staticmethod + async def _add_messages_to_context( + model_context: ChatCompletionContext, + messages: Sequence[BaseChatMessage], + ) -> None: + """ + Add incoming messages to the model context. + """ + for msg in messages: + if isinstance(msg, HandoffMessage): + for llm_msg in msg.context: + await model_context.add_message(llm_msg) + await model_context.add_message(msg.to_model_message()) + async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str: """Selects the next speaker in a group chat using a ChatCompletion client, with the selector function as override if it returns a speaker name. A key assumption is that the agent type is the same as the topic type, which we use as the agent name. """ + # TODO: A hacky solution - Update model context from _message_thread at every speaker selection + # Add last BaseChatMessage to model context + if isinstance(thread[-1], BaseChatMessage): + await self._model_context.add_message(thread[-1].to_model_message()) # Use the selector function if provided. if self._selector_func is not None: @@ -168,9 +192,6 @@ async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) - assert len(participants) > 0 - # Construct the history of the conversation. - history = self.construct_message_history(thread) - # Construct agent roles. # Each agent sould appear on a single line. roles = "" @@ -180,7 +201,7 @@ async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) - # Select the next speaker. 
if len(participants) > 1: - agent_name = await self._select_speaker(roles, participants, history, self._max_selector_attempts) + agent_name = await self._select_speaker(roles, participants, self._max_selector_attempts) else: agent_name = participants[0] self._previous_speaker = agent_name @@ -205,13 +226,12 @@ def construct_message_history( history: str = "\n".join(history_messages) return history - async def _select_speaker(self, roles: str, participants: List[str], history: str, max_attempts: int) -> str: + async def _select_speaker(self, roles: str, participants: List[str], max_attempts: int) -> str: model_context_messages = await self._model_context.get_messages() model_context_history = self.construct_message_history(model_context_messages) # type: ignore - history = model_context_history + history select_speaker_prompt = self._selector_prompt.format( - roles=roles, participants=str(participants), history=history + roles=roles, participants=str(participants), history=model_context_history ) select_speaker_messages: List[SystemMessage | UserMessage | AssistantMessage] From b37ec18d132aca16844429f3c825b54b2fc0fc46 Mon Sep 17 00:00:00 2001 From: Abhijeetsingh Meena Date: Tue, 22 Apr 2025 22:42:24 +0530 Subject: [PATCH 03/10] Use `UnboundedChatCompletionContext` as default chat completion context --- .../teams/_group_chat/_selector_group_chat.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index 57e4bb8421a7..347c0f487a39 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -6,7 +6,7 @@ from autogen_core import AgentRuntime, CancellationToken, Component, ComponentModel from autogen_core.model_context import ( - BufferedChatCompletionContext, + UnboundedChatCompletionContext, ChatCompletionContext, ) from autogen_core.models import ( @@ -94,8 +94,7 @@ def __init__( if model_context is not None: self._model_context = model_context else: - # TODO: finalize the best default context class - self._model_context = BufferedChatCompletionContext(buffer_size=5) + self._model_context = UnboundedChatCompletionContext() self._cancellation_token = CancellationToken() async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None: From 4f0b015fb6572eed1b825c8c47ef65b904b8c8dd Mon Sep 17 00:00:00 2001 From: Abhijeetsingh Meena Date: Tue, 22 Apr 2025 23:18:51 +0530 Subject: [PATCH 04/10] Refactor message thread updates with `update_message_thread` method - Added `update_message_thread` method in `BaseGroupChatManager` to manage message thread updates. - Replaced direct `_message_thread` modifications with calls to this method. - Overrode `update_message_thread` in `SelectorGroupChat` to also update the `model_context`. 
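
A condensed sketch of the hook these bullets describe (taken from the diffs
below; the method bodies are shown without their surrounding classes):

    # BaseGroupChatManager: every thread append now funnels through one method.
    async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None:
        self._message_thread.extend(messages)

    # SelectorGroupChatManager override: also mirror chat messages into the
    # model context that drives speaker selection.
    async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None:
        self._message_thread.extend(messages)
        base_chat_messages = [m for m in messages if isinstance(m, BaseChatMessage)]
        await self._add_messages_to_context(self._model_context, base_chat_messages)

Routing updates through a single awaitable method gives subclasses one seam
for side effects, instead of scattering `_message_thread.append` calls
across handlers.
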
Signed-off-by: Abhijeetsingh Meena --- .../teams/_group_chat/_base_group_chat_manager.py | 9 ++++++--- .../_magentic_one/_magentic_one_orchestrator.py | 8 ++++---- .../teams/_group_chat/_selector_group_chat.py | 12 ++++++------ 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py index afd3407620b7..4c459f367416 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py @@ -115,7 +115,7 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No ) # Append all messages to thread - self._message_thread.extend(message.messages) + await self.update_message_thread(message.messages) # Check termination condition after processing all messages if await self._apply_termination_condition(message.messages): @@ -139,6 +139,9 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No cancellation_token=ctx.cancellation_token, ) + async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None: + self._message_thread.extend(messages) + @event async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None: try: @@ -146,9 +149,9 @@ async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: Mess delta: List[BaseAgentEvent | BaseChatMessage] = [] if message.agent_response.inner_messages is not None: for inner_message in message.agent_response.inner_messages: - self._message_thread.append(inner_message) + await self.update_message_thread([inner_message]) delta.append(inner_message) - self._message_thread.append(message.agent_response.chat_message) + await self.update_message_thread([message.agent_response.chat_message]) delta.append(message.agent_response.chat_message) # Check if the conversation should be terminated. diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py index 78bf3929a046..34d1df7cf948 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py @@ -191,7 +191,7 @@ async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: Mess if message.agent_response.inner_messages is not None: for inner_message in message.agent_response.inner_messages: delta.append(inner_message) - self._message_thread.append(message.agent_response.chat_message) + await self.update_message_thread([message.agent_response.chat_message]) delta.append(message.agent_response.chat_message) if self._termination_condition is not None: @@ -263,7 +263,7 @@ async def _reenter_outer_loop(self, cancellation_token: CancellationToken) -> No ) # Save my copy - self._message_thread.append(ledger_message) + await self.update_message_thread([ledger_message]) # Log it to the output topic. 
await self.publish_message( @@ -376,7 +376,7 @@ async def _orchestrate_step(self, cancellation_token: CancellationToken) -> None # Broadcast the next step message = TextMessage(content=progress_ledger["instruction_or_question"]["answer"], source=self._name) - self._message_thread.append(message) # My copy + await self.update_message_thread([message]) # My copy await self._log_message(f"Next Speaker: {progress_ledger['next_speaker']['answer']}") # Log it to the output topic. @@ -458,7 +458,7 @@ async def _prepare_final_answer(self, reason: str, cancellation_token: Cancellat assert isinstance(response.content, str) message = TextMessage(content=response.content, source=self._name) - self._message_thread.append(message) # My copy + await self.update_message_thread([message]) # My copy # Log it to the output topic. await self.publish_message( diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index 347c0f487a39..9680c216b4d1 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -6,8 +6,8 @@ from autogen_core import AgentRuntime, CancellationToken, Component, ComponentModel from autogen_core.model_context import ( - UnboundedChatCompletionContext, ChatCompletionContext, + UnboundedChatCompletionContext, ) from autogen_core.models import ( AssistantMessage, @@ -139,17 +139,17 @@ async def _add_messages_to_context( await model_context.add_message(llm_msg) await model_context.add_message(msg.to_model_message()) + async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None: + self._message_thread.extend(messages) + base_chat_messages = [m for m in messages if isinstance(m, BaseChatMessage)] + await self._add_messages_to_context(self._model_context, base_chat_messages) + async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) -> str: """Selects the next speaker in a group chat using a ChatCompletion client, with the selector function as override if it returns a speaker name. A key assumption is that the agent type is the same as the topic type, which we use as the agent name. """ - # TODO: A hacky solution - Update model context from _message_thread at every speaker selection - # Add last BaseChatMessage to model context - if isinstance(thread[-1], BaseChatMessage): - await self._model_context.add_message(thread[-1].to_model_message()) - # Use the selector function if provided. 
if self._selector_func is not None:
             if self._is_selector_func_async:

From 3bc42fea39d14fb3438eaaaabe5b48070895851a Mon Sep 17 00:00:00 2001
From: Abhijeetsingh Meena
Date: Wed, 23 Apr 2025 21:52:49 +0530
Subject: [PATCH 05/10] Update message thread in one call using delta

---
 .../teams/_group_chat/_base_group_chat_manager.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py
index 4c459f367416..1ebe658c18e4 100644
--- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py
+++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py
@@ -149,10 +149,9 @@ async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: Mess
         delta: List[BaseAgentEvent | BaseChatMessage] = []
         if message.agent_response.inner_messages is not None:
             for inner_message in message.agent_response.inner_messages:
-                await self.update_message_thread([inner_message])
                 delta.append(inner_message)
-        await self.update_message_thread([message.agent_response.chat_message])
         delta.append(message.agent_response.chat_message)
+        await self.update_message_thread(delta)

         # Check if the conversation should be terminated.
         if await self._apply_termination_condition(delta, increment_turn_count=True):

From 35e1da931aba91430b7e3df85fac4495b2893aff Mon Sep 17 00:00:00 2001
From: Abhijeetsingh Meena
Date: Wed, 23 Apr 2025 21:55:45 +0530
Subject: [PATCH 06/10] Add unit test for `model_context` in SelectorGroupChat

---
 .../tests/test_group_chat.py | 79 +++++++++++++++----
 1 file changed, 62 insertions(+), 17 deletions(-)

diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py
index 947d9595ba97..f920bafca933 100644
--- a/python/packages/autogen-agentchat/tests/test_group_chat.py
+++ b/python/packages/autogen-agentchat/tests/test_group_chat.py
@@ -2,10 +2,27 @@
 import json
 import logging
 import tempfile
-from typing import Any, AsyncGenerator, List, Mapping, Sequence
+from typing import Any, AsyncGenerator, Dict, List, Mapping, Sequence

 import pytest
 import pytest_asyncio
+from autogen_core import AgentId, AgentRuntime, CancellationToken, FunctionCall, SingleThreadedAgentRuntime
+from autogen_core.models import (
+    AssistantMessage,
+    CreateResult,
+    FunctionExecutionResult,
+    FunctionExecutionResultMessage,
+    LLMMessage,
+    
RequestUsage, - UserMessage, -) -from autogen_core.tools import FunctionTool -from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor -from autogen_ext.models.openai import OpenAIChatCompletionClient -from autogen_ext.models.replay import ReplayChatCompletionClient -from pydantic import BaseModel -from utils import FileLogHandler logger = logging.getLogger(EVENT_LOGGER_NAME) logger.setLevel(logging.DEBUG) @@ -691,6 +692,50 @@ async def test_selector_group_chat(runtime: AgentRuntime | None) -> None: assert result2 == result +@pytest.mark.asyncio +async def test_selector_group_chat_with_model_context(runtime: AgentRuntime | None) -> None: + selector_group_chat_model_client = ReplayChatCompletionClient(["agent2", "agent1", "agent1", "agent2"]) + agent_one_model_client = ReplayChatCompletionClient( + ["[Agent One] First generation", "[Agent One] Second generation", "TERMINATE"] + ) + agent_two_model_client = ReplayChatCompletionClient( + ["[Agent Two] First generation", "[Agent Two] Second generation", "TERMINATE"] + ) + + agent1 = AssistantAgent("agent1", model_client=agent_one_model_client, description="Assistant agent 1") + agent2 = AssistantAgent("agent2", model_client=agent_two_model_client, description="Assistant agent 2") + + termination = TextMentionTermination("TERMINATE") + team = SelectorGroupChat( + participants=[agent1, agent2], + model_client=selector_group_chat_model_client, + termination_condition=termination, + runtime=runtime, + emit_team_events=True, + ) + await team.run( + task="[GroupChat] Task", + ) + + messages_to_check = { + "1": "user: [GroupChat] Task", + "2": "agent2: [Agent Two] First generation", + "3": "agent1: [Agent One] First generation", + "4": "agent1: [Agent One] Second generation", + "5": "agent2: [Agent Two] Second generation", + } + + create_calls: List[Dict[str, Any]] = selector_group_chat_model_client.create_calls + for idx, call in enumerate(create_calls): + messages = call["messages"] + prompt = messages[0].content + prompt_lines = prompt.split("\n") + chat_history = [value for _, value in list(messages_to_check.items())[: idx + 1]] + assert all( + line.strip() in prompt_lines for line in chat_history + ), f"Expected all lines {chat_history} to be in prompt, but got {prompt_lines}" + + @pytest.mark.asyncio async def test_selector_group_chat_with_team_event(runtime: AgentRuntime | None) -> None: model_client = ReplayChatCompletionClient( From 2580acfccd75b2b3fc82bfabd528320fb82efbad Mon Sep 17 00:00:00 2001 From: Abhijeetsingh Meena Date: Thu, 24 Apr 2025 20:28:31 +0530 Subject: [PATCH 07/10] Refactor message history construction to use LLMMessage Signed-off-by: Abhijeetsingh Meena --- .../teams/_group_chat/_selector_group_chat.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index 9680c216b4d1..3b9020b4e6ab 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -12,6 +12,7 @@ from autogen_core.models import ( AssistantMessage, ChatCompletionClient, + LLMMessage, ModelFamily, SystemMessage, UserMessage, @@ -207,16 +208,11 @@ async def select_speaker(self, thread: List[BaseAgentEvent | BaseChatMessage]) - trace_logger.debug(f"Selected speaker: 
{agent_name}") return agent_name - def construct_message_history( - self, message_history: Sequence[Union[BaseChatMessage, BaseAgentEvent, UserMessage, AssistantMessage]] - ) -> str: + def construct_message_history(self, message_history: Sequence[LLMMessage]) -> str: # Construct the history of the conversation. history_messages: List[str] = [] for msg in message_history: - if isinstance(msg, BaseChatMessage): - message = f"{msg.source}: {msg.to_model_text()}" - history_messages.append(message.rstrip() + "\n\n") - elif isinstance(msg, UserMessage) or isinstance(msg, AssistantMessage): + if isinstance(msg, UserMessage) or isinstance(msg, AssistantMessage): message = f"{msg.source}: {msg.content}" history_messages.append( message.rstrip() + "\n\n" From 63212efddd58ed719c43dd803c96d37ff079127b Mon Sep 17 00:00:00 2001 From: Abhijeetsingh Meena Date: Thu, 1 May 2025 22:18:43 +0530 Subject: [PATCH 08/10] Add unit test to check `model_context` parameter in `SelectorGroupChat` Signed-off-by: Abhijeetsingh Meena --- .../tests/test_group_chat.py | 33 ++++++++++++------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py index f920bafca933..ed40294d510d 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat.py @@ -7,6 +7,7 @@ import pytest import pytest_asyncio from autogen_core import AgentId, AgentRuntime, CancellationToken, FunctionCall, SingleThreadedAgentRuntime +from autogen_core.model_context import BufferedChatCompletionContext from autogen_core.models import ( AssistantMessage, CreateResult, @@ -694,12 +695,17 @@ async def test_selector_group_chat(runtime: AgentRuntime | None) -> None: @pytest.mark.asyncio async def test_selector_group_chat_with_model_context(runtime: AgentRuntime | None) -> None: - selector_group_chat_model_client = ReplayChatCompletionClient(["agent2", "agent1", "agent1", "agent2"]) + buffered_context = BufferedChatCompletionContext(buffer_size=5) + await buffered_context.add_message(UserMessage(content="[User] Prefilled message", source="user")) + + selector_group_chat_model_client = ReplayChatCompletionClient( + ["agent2", "agent1", "agent1", "agent2", "agent1", "agent2", "agent1"] + ) agent_one_model_client = ReplayChatCompletionClient( - ["[Agent One] First generation", "[Agent One] Second generation", "TERMINATE"] + ["[Agent One] First generation", "[Agent One] Second generation", "[Agent One] Third generation", "TERMINATE"] ) agent_two_model_client = ReplayChatCompletionClient( - ["[Agent Two] First generation", "[Agent Two] Second generation", "TERMINATE"] + ["[Agent Two] First generation", "[Agent Two] Second generation", "[Agent Two] Third generation"] ) agent1 = AssistantAgent("agent1", model_client=agent_one_model_client, description="Assistant agent 1") @@ -712,25 +718,30 @@ async def test_selector_group_chat_with_model_context(runtime: AgentRuntime | No termination_condition=termination, runtime=runtime, emit_team_events=True, + allow_repeated_speaker=True, + model_context=buffered_context, ) await team.run( task="[GroupChat] Task", ) - messages_to_check = { - "1": "user: [GroupChat] Task", - "2": "agent2: [Agent Two] First generation", - "3": "agent1: [Agent One] First generation", - "4": "agent1: [Agent One] Second generation", - "5": "agent2: [Agent Two] Second generation", - } + messages_to_check = [ + "user: [User] Prefilled message", + "user: 
[GroupChat] Task", + "agent2: [Agent Two] First generation", + "agent1: [Agent One] First generation", + "agent1: [Agent One] Second generation", + "agent2: [Agent Two] Second generation", + "agent1: [Agent One] Third generation", + "agent2: [Agent Two] Third generation", + ] create_calls: List[Dict[str, Any]] = selector_group_chat_model_client.create_calls for idx, call in enumerate(create_calls): messages = call["messages"] prompt = messages[0].content prompt_lines = prompt.split("\n") - chat_history = [value for _, value in list(messages_to_check.items())[: idx + 1]] + chat_history = [value for value in messages_to_check[max(0, idx - 3) : idx + 2]] assert all( line.strip() in prompt_lines for line in chat_history ), f"Expected all lines {chat_history} to be in prompt, but got {prompt_lines}" From 26f6de008e22730396e825188932b30ebe6e86a4 Mon Sep 17 00:00:00 2001 From: Abhijeetsingh Meena Date: Thu, 1 May 2025 22:31:44 +0530 Subject: [PATCH 09/10] Update API documentation to include `model_context` usage in `SelectorGroupChat` Signed-off-by: Abhijeetsingh Meena --- .../teams/_group_chat/_selector_group_chat.py | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index 3b9020b4e6ab..8f48e28f8980 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -356,6 +356,8 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]): selection using model. If the function returns an empty list or `None`, `SelectorGroupChat` will raise a `ValueError`. This function is only used if `selector_func` is not set. The `allow_repeated_speaker` will be ignored if set. emit_team_events (bool, optional): Whether to emit team events through :meth:`BaseGroupChat.run_stream`. Defaults to False. + model_context (ChatCompletionContext | None, optional): The model context for storing and retrieving + :class:`~autogen_core.models.LLMMessage`. It can be preloaded with initial messages. Messages stored in model context will be used for speaker selection. The initial messages will be cleared when the team is reset. Raises: ValueError: If the number of participants is less than two or if the selector prompt is invalid. @@ -470,6 +472,64 @@ def selector_func(messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> str | await Console(team.run_stream(task="What is 1 + 1?")) + asyncio.run(main()) + + A team with custom model context: + + .. code-block:: python + + import asyncio + + from autogen_core.model_context import BufferedChatCompletionContext + from autogen_ext.models.openai import OpenAIChatCompletionClient + + from autogen_agentchat.agents import AssistantAgent + from autogen_agentchat.conditions import TextMentionTermination + from autogen_agentchat.teams import SelectorGroupChat + from autogen_agentchat.ui import Console + + + async def main() -> None: + model_client = OpenAIChatCompletionClient(model="gpt-4o") + model_context = BufferedChatCompletionContext(buffer_size=5) + + async def lookup_hotel(location: str) -> str: + return f"Here are some hotels in {location}: hotel1, hotel2, hotel3." 
+ + async def lookup_flight(origin: str, destination: str) -> str: + return f"Here are some flights from {origin} to {destination}: flight1, flight2, flight3." + + async def book_trip() -> str: + return "Your trip is booked!" + + travel_advisor = AssistantAgent( + "Travel_Advisor", + model_client, + tools=[book_trip], + description="Helps with travel planning.", + ) + hotel_agent = AssistantAgent( + "Hotel_Agent", + model_client, + tools=[lookup_hotel], + description="Helps with hotel booking.", + ) + flight_agent = AssistantAgent( + "Flight_Agent", + model_client, + tools=[lookup_flight], + description="Helps with flight booking.", + ) + termination = TextMentionTermination("TERMINATE") + team = SelectorGroupChat( + [travel_advisor, hotel_agent, flight_agent], + model_client=model_client, + termination_condition=termination, + model_context=model_context, + ) + await Console(team.run_stream(task="Book a 3-day trip to new york.")) + + asyncio.run(main()) """ From 57dbeaa148223021489ea3cfe9bed3f9aa7310b7 Mon Sep 17 00:00:00 2001 From: Abhijeetsingh Meena Date: Thu, 1 May 2025 22:57:30 +0530 Subject: [PATCH 10/10] Add `model_context` to `SelectorGroupChatConfig` Signed-off-by: Abhijeetsingh Meena --- .../teams/_group_chat/_selector_group_chat.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index 21a490ac6d3d..4e4347cb9929 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -71,8 +71,8 @@ def __init__( max_selector_attempts: int, candidate_func: Optional[CandidateFuncType], emit_team_events: bool, - model_client_streaming: bool = False, model_context: ChatCompletionContext | None, + model_client_streaming: bool = False, ) -> None: super().__init__( name, @@ -351,6 +351,7 @@ class SelectorGroupChatConfig(BaseModel): max_selector_attempts: int = 3 emit_team_events: bool = False model_client_streaming: bool = False + model_context: ComponentModel | None = None class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]): @@ -646,8 +647,8 @@ def _create_group_chat_manager_factory( self._max_selector_attempts, self._candidate_func, self._emit_team_events, - self._model_client_streaming, self._model_context, + self._model_client_streaming, ) def _to_config(self) -> SelectorGroupChatConfig: @@ -662,6 +663,7 @@ def _to_config(self) -> SelectorGroupChatConfig: # selector_func=self._selector_func.dump_component() if self._selector_func else None, emit_team_events=self._emit_team_events, model_client_streaming=self._model_client_streaming, + model_context=self._model_context.dump_component() if self._model_context else None, ) @classmethod @@ -681,4 +683,5 @@ def _from_config(cls, config: SelectorGroupChatConfig) -> Self: # else None, emit_team_events=config.emit_team_events, model_client_streaming=config.model_client_streaming, + model_context=ChatCompletionContext.load_component(config.model_context) if config.model_context else None, )
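
Taken together, the series lets callers hand `SelectorGroupChat` any
`ChatCompletionContext` and carries that context through component
serialization. A minimal usage sketch (the agent setup, the task, and a
configured OpenAI API key are illustrative assumptions, not part of the
patches):

    import asyncio

    from autogen_core.model_context import BufferedChatCompletionContext
    from autogen_ext.models.openai import OpenAIChatCompletionClient

    from autogen_agentchat.agents import AssistantAgent
    from autogen_agentchat.conditions import TextMentionTermination
    from autogen_agentchat.teams import SelectorGroupChat


    async def main() -> None:
        model_client = OpenAIChatCompletionClient(model="gpt-4o")
        agent1 = AssistantAgent("agent1", model_client=model_client, description="First assistant")
        agent2 = AssistantAgent("agent2", model_client=model_client, description="Second assistant")

        team = SelectorGroupChat(
            [agent1, agent2],
            model_client=model_client,
            termination_condition=TextMentionTermination("TERMINATE"),
            # The selector prompt sees only the last 5 LLM messages.
            model_context=BufferedChatCompletionContext(buffer_size=5),
        )

        # Patch 10 adds the context to SelectorGroupChatConfig, so a
        # dump/load round trip should preserve the buffered context.
        config = team.dump_component()
        restored = SelectorGroupChat.load_component(config)

        await restored.run(task="Reply with TERMINATE.")


    asyncio.run(main())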